Commit cf647057 by Leo

upload code

parent c42b8782
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"machine_shape": "hm"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "To9ENLU90WGl",
"outputId": "4b46c997-c16c-4141-eaf2-e7aa7da6d3a0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 632
}
},
"source": [
"!pip install transformers"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting transformers\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/fd/f9/51824e40f0a23a49eab4fcaa45c1c797cbf9761adedd0b558dab7c958b34/transformers-2.1.1-py3-none-any.whl (311kB)\n",
"\r\u001b[K |█ | 10kB 15.1MB/s eta 0:00:01\r\u001b[K |██ | 20kB 2.2MB/s eta 0:00:01\r\u001b[K |███▏ | 30kB 3.2MB/s eta 0:00:01\r\u001b[K |████▏ | 40kB 2.1MB/s eta 0:00:01\r\u001b[K |█████▎ | 51kB 2.6MB/s eta 0:00:01\r\u001b[K |██████▎ | 61kB 3.1MB/s eta 0:00:01\r\u001b[K |███████▍ | 71kB 3.6MB/s eta 0:00:01\r\u001b[K |████████▍ | 81kB 4.1MB/s eta 0:00:01\r\u001b[K |█████████▌ | 92kB 4.5MB/s eta 0:00:01\r\u001b[K |██████████▌ | 102kB 3.5MB/s eta 0:00:01\r\u001b[K |███████████▋ | 112kB 3.5MB/s eta 0:00:01\r\u001b[K |████████████▋ | 122kB 3.5MB/s eta 0:00:01\r\u001b[K |█████████████▊ | 133kB 3.5MB/s eta 0:00:01\r\u001b[K |██████████████▊ | 143kB 3.5MB/s eta 0:00:01\r\u001b[K |███████████████▊ | 153kB 3.5MB/s eta 0:00:01\r\u001b[K |████████████████▉ | 163kB 3.5MB/s eta 0:00:01\r\u001b[K |█████████████████▉ | 174kB 3.5MB/s eta 0:00:01\r\u001b[K |███████████████████ | 184kB 3.5MB/s eta 0:00:01\r\u001b[K |████████████████████ | 194kB 3.5MB/s eta 0:00:01\r\u001b[K |█████████████████████ | 204kB 3.5MB/s eta 0:00:01\r\u001b[K |██████████████████████ | 215kB 3.5MB/s eta 0:00:01\r\u001b[K |███████████████████████▏ | 225kB 3.5MB/s eta 0:00:01\r\u001b[K |████████████████████████▏ | 235kB 3.5MB/s eta 0:00:01\r\u001b[K |█████████████████████████▎ | 245kB 3.5MB/s eta 0:00:01\r\u001b[K |██████████████████████████▎ | 256kB 3.5MB/s eta 0:00:01\r\u001b[K |███████████████████████████▍ | 266kB 3.5MB/s eta 0:00:01\r\u001b[K |████████████████████████████▍ | 276kB 3.5MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▍ | 286kB 3.5MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▌ | 296kB 3.5MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▌| 307kB 3.5MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 317kB 3.5MB/s \n",
"\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.17.4)\n",
"Collecting sentencepiece\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/14/3d/efb655a670b98f62ec32d66954e1109f403db4d937c50d779a75b9763a29/sentencepiece-0.1.83-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)\n",
"\u001b[K |████████████████████████████████| 1.0MB 53.8MB/s \n",
"\u001b[?25hRequirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers) (1.10.14)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from transformers) (4.28.1)\n",
"Collecting regex\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/e3/8e/cbf2295643d7265e7883326fb4654e643bfc93b3a8a8274d8010a39d8804/regex-2019.11.1-cp36-cp36m-manylinux1_x86_64.whl (643kB)\n",
"\u001b[K |████████████████████████████████| 645kB 39.9MB/s \n",
"\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.21.0)\n",
"Collecting sacremoses\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/1f/8e/ed5364a06a9ba720fddd9820155cc57300d28f5f43a6fd7b7e817177e642/sacremoses-0.0.35.tar.gz (859kB)\n",
"\u001b[K |████████████████████████████████| 860kB 48.8MB/s \n",
"\u001b[?25hRequirement already satisfied: botocore<1.14.0,>=1.13.14 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (1.13.14)\n",
"Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.2.1)\n",
"Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.9.4)\n",
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2019.9.11)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.8)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.0)\n",
"Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.14.0)\n",
"Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.14->boto3->transformers) (0.15.2)\n",
"Requirement already satisfied: python-dateutil<2.8.1,>=2.1; python_version >= \"2.7\" in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.14->boto3->transformers) (2.6.1)\n",
"Building wheels for collected packages: sacremoses\n",
" Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for sacremoses: filename=sacremoses-0.0.35-cp36-none-any.whl size=883999 sha256=cb6e175a99d8f14f69f593b734ea73303c23a7dbdc694be65435cfeebd1f3124\n",
" Stored in directory: /root/.cache/pip/wheels/63/2a/db/63e2909042c634ef551d0d9ac825b2b0b32dede4a6d87ddc94\n",
"Successfully built sacremoses\n",
"Installing collected packages: sentencepiece, regex, sacremoses, transformers\n",
"Successfully installed regex-2019.11.1 sacremoses-0.0.35 sentencepiece-0.1.83 transformers-2.1.1\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "fvFvBLJV0Dkv",
"outputId": "140119e5-4cee-4604-c0d2-be279c18b125",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 63
}
},
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.model_selection import cross_val_score\n",
"import torch\n",
"import transformers as ppb\n",
"import warnings\n",
"warnings.filterwarnings('ignore')"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<p style=\"color: red;\">\n",
"The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x.<br>\n",
"We recommend you <a href=\"https://www.tensorflow.org/guide/migrate\" target=\"_blank\">upgrade</a> now \n",
"or ensure your notebook will continue to use TensorFlow 1.x via the <code>%tensorflow_version 1.x</code> magic:\n",
"<a href=\"https://colab.research.google.com/notebooks/tensorflow_version.ipynb\" target=\"_blank\">more info</a>.</p>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "cyoj29J24hPX"
},
"source": [
"df = pd.read_csv('https://github.com/clairett/pytorch-sentiment-classification/raw/master/data/SST2/train.tsv', delimiter='\\t', header=None)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "gTM3hOHW4hUY"
},
"source": [
"batch_1 = df[:2000]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "PRc2L89hh1Tf"
},
"source": [
"- 查看数据中被标记为“积极”和“消极”的句子的个数"
]
},
{
"cell_type": "code",
"metadata": {
"id": "jGvcfcCP5xpZ",
"outputId": "4c4a8afc-1035-4b21-ba9a-c4bb6cfc6347",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"source": [
"batch_1[1].value_counts()"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1 1041\n",
"0 959\n",
"Name: 1, dtype: int64"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7_MO08_KiAOb"
},
"source": [
"## Loading the Pre-trained BERT model\n",
"- 加载预训练的Bert模型"
]
},
{
"cell_type": "code",
"metadata": {
"id": "q1InADgf5xm2",
"outputId": "dbc52856-4d52-42f8-8a74-a89944280a02",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"source": [
"# For DistilBERT:\n",
"model_class, tokenizer_class, pretrained_weights = (ppb.DistilBertModel, ppb.DistilBertTokenizer, 'distilbert-base-uncased')\n",
"\n",
"## 如果你希望使用Bert执行代码则将下一行代码反注释\n",
"#model_class, tokenizer_class, pretrained_weights = (ppb.BertModel, ppb.BertTokenizer, 'bert-base-uncased')\n",
"\n",
"# Load pretrained model/tokenizer\n",
"tokenizer = tokenizer_class.from_pretrained(pretrained_weights)\n",
"model = model_class.from_pretrained(pretrained_weights)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"100%|██████████| 231508/231508 [00:00<00:00, 2649246.11B/s]\n",
"100%|██████████| 492/492 [00:00<00:00, 284634.15B/s]\n",
"100%|██████████| 267967963/267967963 [00:03<00:00, 72728701.55B/s]\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Dg82ndBA5xlN"
},
"source": [
"tokenized = batch_1[0].apply((lambda x: tokenizer.encode(x, add_special_tokens=True)))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "URn-DWJt5xhP"
},
"source": [
"max_len = 0\n",
"for i in tokenized.values:\n",
" if len(i) > max_len:\n",
" max_len = len(i)\n",
"\n",
"padded = np.array([i + [0]*(max_len-len(i)) for i in tokenized.values])"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "jdi7uXo95xeq",
"outputId": "be786022-e84f-4e28-8531-0143af2347bc",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"np.array(padded).shape"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(2000, 59)"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sDZBsYSDjzDV"
},
"source": [
"### Masking\n",
"- 当存在padding操作时,会出现空数据,亦即无效数据,此时就需要使用Masking对补充数据进行掩码操作"
]
},
{
"cell_type": "code",
"metadata": {
"id": "4K_iGRNa_Ozc",
"outputId": "d03b0a9b-1f6e-4e32-831e-b04f5389e57c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"attention_mask = np.where(padded != 0, 1, 0)\n",
"attention_mask.shape"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(2000, 59)"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "39UVjAV56PJz"
},
"source": [
"input_ids = torch.tensor(padded)\n",
"attention_mask = torch.tensor(attention_mask)\n",
"\n",
"with torch.no_grad():\n",
" last_hidden_states = model(input_ids, attention_mask=attention_mask)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "C9t60At16PVs"
},
"source": [
"features = last_hidden_states[0][:,0,:].numpy() # 使用CLS标签作为输出"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JD3fX2yh6PTx"
},
"source": [
"labels = batch_1[1]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "iaoEvM2evRx1"
},
"source": [
"## Model #2: Train/Test Split\n",
"- 拆分训练数据"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ddAqbkoU6PP9"
},
"source": [
"train_features, test_features, train_labels, test_labels = train_test_split(features, labels)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "gG-EVWx4CzBc",
"outputId": "9252ceff-a7d0-4359-fef9-2f72be89c7d6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 102
}
},
"source": [
"lr_clf = LogisticRegression()\n",
"lr_clf.fit(train_features, train_labels)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n",
" intercept_scaling=1, l1_ratio=None, max_iter=100,\n",
" multi_class='warn', n_jobs=None, penalty='l2',\n",
" random_state=None, solver='warn', tol=0.0001, verbose=0,\n",
" warm_start=False)"
]
},
"metadata": {
"tags": []
},
"execution_count": 21
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3rUMKuVgwzkY"
},
"source": [
"## Evaluating Model #2\n",
"\n",
"- 评估模型"
]
},
{
"cell_type": "code",
"metadata": {
"id": "iCoyxRJ7ECTA",
"outputId": "cfd86dea-5d16-476c-ab9b-47cbee3a014f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"lr_clf.score(test_features, test_labels)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.824"
]
},
"metadata": {
"tags": []
},
"execution_count": 22
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "75oyhr3VxHoE"
},
"source": [
"- 最简单的评估比较方式:将结果与dummy-classify进行比对"
]
},
{
"cell_type": "code",
"metadata": {
"id": "lnwgmqNG7i5l",
"outputId": "0042aed2-4fa8-4fa0-bf25-fdef70a10aac",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"from sklearn.dummy import DummyClassifier\n",
"clf = DummyClassifier()\n",
"\n",
"scores = cross_val_score(clf, train_features, train_labels)\n",
"print(\"Dummy classifier score: %0.3f (+/- %0.2f)\" % (scores.mean(), scores.std() * 2))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Dummy classifier score: 0.527 (+/- 0.05)\n"
],
"name": "stdout"
}
]
}
]
}
\ No newline at end of file
import numpy as np
def get_positional_encoding(max_len, d_model):
    """Sinusoidal positional encoding from the Transformer paper."""
    positional_encoding = np.zeros((max_len, d_model))
    for pos in range(max_len):
        for i in range(0, d_model, 2):  # i is the even dimension index 2k
            angle = pos / (10000 ** (i / d_model))  # exponent is i/d_model since i already equals 2k
            positional_encoding[pos, i] = np.sin(angle)
            if i + 1 < d_model:
                positional_encoding[pos, i + 1] = np.cos(angle)
    return positional_encoding
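# Quick sanity check (illustrative only, not part of the original snippet):
# the encoding has shape (max_len, d_model) and all values lie in [-1, 1].
pe = get_positional_encoding(max_len=50, d_model=16)
print(pe.shape)            # (50, 16)
print(pe.min(), pe.max())  # bounded by the sine/cosine range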
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"mount_file_id": "1dnzq6rCBj8s3iRHRvUCK121ffef6jvJw",
"authorship_tag": "ABX9TyNHwL5IYuVsZNX9ms9OJVuq",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/ymoslem/Adaptive-MT-LLM-Fine-tuning/blob/main/ChatGPT-Adaptive-MT.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Batch Translation with ChatGPT\n",
"\n",
"This notebook is part of the repository [Adaptive-MT-LLM-Fine-tuning](https://github.com/ymoslem/Adaptive-MT-LLM-Fine-tuning)."
],
"metadata": {
"id": "ta3i3wddYId3"
}
},
{
"cell_type": "markdown",
"source": [
"# Load files"
],
"metadata": {
"id": "2yFOBmnFLUEN"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"\n",
"data_path = \"/content/drive/MyDrive/data/\"\n",
"directory = os.path.join(data_path, \"spanish\")\n",
"\n",
"os.chdir(directory)\n",
"os.getcwd()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "thux1ESjCF9H",
"outputId": "859152a6-db31-4feb-b2d4-0048a1a7fe67"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'/content/drive/MyDrive/data/spanish'"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
}
},
"metadata": {},
"execution_count": 18
}
]
},
{
"cell_type": "code",
"source": [
"# Load test dataset\n",
"\n",
"source_test_file = \"all-filtered.es.real.test\"\n",
"target_test_file = \"all-filtered.en.real.test\"\n",
"\n",
"with open(source_test_file, encoding=\"utf-8\") as source, open(target_test_file, encoding=\"utf-8\") as target:\n",
" source_sentences = [sent.strip() for sent in source.readlines()]\n",
" target_sentences = [sent.strip() for sent in target.readlines()]\n",
"\n",
"print(source_sentences[0])\n",
"print(target_sentences[0])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XdQA6lbBCI5M",
"outputId": "cbd2fbab-b90d-42c9-def7-8e8742174e58"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Período de validez después de abierto el envase: 10 horas.\n",
"Shelf life after first opening the container: 10 hours.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Load fuzzy matches from the Context Dataset\n",
"\n",
"online_test_file = \"all-filtered.esen.ms-multi-12.online.test\"\n",
"\n",
"with open(online_test_file, encoding=\"utf-8\") as online:\n",
" lines = [line.strip().split(\" ||| \") for line in online.readlines()]\n",
" scores = [float(line[0].strip()) for line in lines]\n",
" fuzzy_source_sentences = [line[1].strip() for line in lines]\n",
" online_source_sentences = [line[2].strip() for line in lines]\n",
" fuzzy_target_prefixes = [line[3].strip() for line in lines]\n",
"\n",
"print(fuzzy_source_sentences[0])\n",
"print(online_source_sentences[0])\n",
"print(fuzzy_target_prefixes[0])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BMB8lQLCCQ2g",
"outputId": "5435b351-35c8-44b5-faa3-fa3aa59a132f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Período de validez después de abierto el envase: 4 semanas\n",
"Período de validez después de abierto el envase: 10 horas.\n",
"Shelf life after opening the immediate packaging: 4 weeks.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Create prompts"
],
"metadata": {
"id": "DYb9dxsAQtpt"
}
},
{
"cell_type": "code",
"source": [
"# Function to create zero-shot and one-shot prompts\n",
"\n",
"def create_prompt(source_lang,\n",
" target_lang,\n",
" fuzzy_sources,\n",
" fuzzy_targets,\n",
" new_sources,\n",
" one_shot=True\n",
" ):\n",
"\n",
" prompts = []\n",
"\n",
" if one_shot:\n",
" for fuzzy_src, fuzzy_tgt, new_src in zip(fuzzy_sources, fuzzy_targets, new_sources):\n",
" fuzzy_src = source_lang + \": \" + fuzzy_src\n",
" fuzzy_tgt = target_lang + \": \" + fuzzy_tgt\n",
" new_src = source_lang + \": \" + new_src\n",
" segment = fuzzy_src + \"\\n\" + fuzzy_tgt + \"\\n\" + new_src + \"\\n\" + target_lang + \":\"\n",
" prompts.append(segment)\n",
" else:\n",
" for new_src in new_sources:\n",
" new_src = source_lang + \": \" + new_src\n",
" segment = new_src + \"\\n\" + target_lang + \":\"\n",
" prompts.append(segment)\n",
"\n",
" return prompts"
],
"metadata": {
"id": "lhtPPU5zDFp0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"source_lang = \"Spanish\"\n",
"target_lang = \"English\""
],
"metadata": {
"id": "l_XXEc-cQvno"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Create prompts\n",
"# Set one_shot=True to create a one-shot prompts\n",
"\n",
"prompts = create_prompt(source_lang,\n",
" target_lang,\n",
" fuzzy_source_sentences,\n",
" fuzzy_target_prefixes,\n",
" online_source_sentences,\n",
" one_shot=False\n",
" )\n",
"\n",
"print(len(prompts))"
],
"metadata": {
"id": "FVTTpMt8QxKX",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "27cbd567-1158-491f-e4c8-219be9e3e4ec"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"10000\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(prompts[0], \"\\n\")\n",
"print(prompts[-1])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HAlrExMiEDhk",
"outputId": "95e17dc4-4585-48a8-b307-404f7840ae15"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Spanish: Período de validez después de abierto el envase: 10 horas.\n",
"English: \n",
"\n",
"Spanish: El mecanismo implicado en esta posible asociación es aún especulativo pero puede reflejar la mayor frecuencia en mujeres por la disfunción del esfínter de Oddi como lo señalado por Freeman y cols en su estudio 2.\n",
"English:\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Generation"
],
"metadata": {
"id": "FXV2gNR-LWu1"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HeVV8GubKcHV"
},
"outputs": [],
"source": [
"!pip3 install openai --upgrade -q"
]
},
{
"cell_type": "code",
"source": [
"# Get OpenAI API key from Colab Secrets\n",
"\n",
"from google.colab import userdata\n",
"OPENAI_API_KEY = userdata.get(\"openai_api_key\")"
],
"metadata": {
"id": "o_216BV0q3mB"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# ChatGPT generation function\n",
"# model: You can change \"gpt-3.5-turbo\" to \"gpt-4\", but for higher costs!\n",
"\n",
"import openai\n",
"from tenacity import retry, stop_after_attempt, wait_random_exponential\n",
"\n",
"\n",
"# ✳️ Add your OpenAI API key\n",
"openai.api_key = OPENAI_API_KEY\n",
"\n",
"@retry(wait=wait_random_exponential(min=2, max=60), stop=stop_after_attempt(6))\n",
"def translate(prompt, max_tokens, model, temperature=0.3, top_p=1):\n",
" response = openai.chat.completions.create(\n",
" model=model,\n",
" temperature=temperature,\n",
" max_tokens=max_tokens,\n",
" messages=[\n",
" {\"role\": \"user\",\n",
" \"content\": prompt}\n",
" ],\n",
" top_p=top_p,\n",
" frequency_penalty=0,\n",
" presence_penalty=0,\n",
" n=1,\n",
" #stop=[\"\\n\"],\n",
" )\n",
"\n",
" return response"
],
"metadata": {
"id": "5nm0YbcxLIB0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Test\n",
"\n",
"test_translation = translate(prompt=prompts[0], max_tokens=100, model=\"gpt-3.5-turbo-1106\")\n",
"print(test_translation.choices[0].message.content.strip())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Bgr8djzsL6iO",
"outputId": "fc7fe039-1c86-4273-8efa-4262e8ce3f12"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Shelf life after opening the package: 10 hours.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Batch Processing"
],
"metadata": {
"id": "icn7U4ODL1Pw"
}
},
{
"cell_type": "code",
"source": [
"# Sending batch requsets\n",
"\n",
"from concurrent import futures\n",
"from concurrent.futures import ThreadPoolExecutor\n",
"\n",
"num_workers = 128\n",
"\n",
"def batch_translate(prompts, **kwargs):\n",
" with futures.ThreadPoolExecutor(max_workers=num_workers) as executor:\n",
" response = executor.map(lambda prompt: translate(prompt=prompt, **kwargs), prompts)\n",
" return list(response)"
],
"metadata": {
"id": "iIXH8GVHLb7g"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Devide a long list of source sentences into smaller batches\n",
"\n",
"def divide_chunks(l, n):\n",
" # looping till length l\n",
" for i in range(0, len(l), n):\n",
" yield l[i:i + n]"
],
"metadata": {
"id": "BM-0YkZuR0ce"
},
"execution_count": null,
"outputs": []
},
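{
"cell_type": "markdown",
"source": [
"For example (illustrative only), `divide_chunks` yields successive fixed-size slices, with a shorter final chunk when the list length is not a multiple of `n`:\n",
"\n",
"```python\n",
"list(divide_chunks([1, 2, 3, 4, 5], 2))\n",
"# [[1, 2], [3, 4], [5]]\n",
"```"
],
"metadata": {}
},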
{
"cell_type": "code",
"source": [
"# Parameters\n",
"\n",
"temperature = 0.3\n",
"top_p = 1\n",
"\n",
"# ✳️ Change the batch size for longer inputs/outputs\n",
"# Note: Trial accounts allow only 3 requests per minute\n",
"batch_size = 20\n",
"\n",
"# ✳️ Change number of source words vs target tokens.\n",
"# Try 4 for French and Spanish; it can be 5 for some other languages like Arabic.\n",
"# You can also use the \"tiktoken\" library to tokenize the source,\n",
"# and then length_multiplier can be based on tokens rather than words.\n",
"length_multiplier = 4"
],
"metadata": {
"id": "qjpu8ANWfFIk"
},
"execution_count": null,
"outputs": []
},
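{
"cell_type": "markdown",
"source": [
"As the comment above suggests, the length estimate could also be token-based rather than word-based. The cell below is an illustrative sketch only: it assumes `tiktoken` is installed (`!pip3 install tiktoken -q`), and the `multiplier` value is an arbitrary example rather than a tuned setting."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Illustrative sketch: estimate max_tokens from source tokens with tiktoken\n",
"import tiktoken\n",
"\n",
"enc = tiktoken.encoding_for_model(\"gpt-3.5-turbo\")\n",
"\n",
"def max_tokens_for(prompt, multiplier=2):\n",
"    # The new source segment is the second-to-last line of the prompt\n",
"    # (e.g. \"Spanish: ...\" followed by \"English:\").\n",
"    source_text = prompt.split(\"\\n\")[-2]\n",
"    return len(enc.encode(source_text)) * multiplier\n",
"\n",
"max_tokens_for(prompts[0], multiplier=2)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},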
{
"cell_type": "code",
"source": [
"# Model name\n",
"\n",
"model = \"gpt-3.5-turbo\"\n",
"\n",
"# Other models\n",
"# model = \"gpt-3.5-turbo-1106\"\n",
"# model = \"gpt-4\"\n",
"# model = \"gpt-4-1106-preview\" # GPT-4 TurboNew"
],
"metadata": {
"id": "q0WoAPhJo51U"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Batch translation\n",
"\n",
"from tqdm.notebook import tqdm\n",
"from time import sleep\n",
"import json\n",
"\n",
"start = 2380 # change to 0\n",
"\n",
"# Translate\n",
"translations = []\n",
"total = int(len(prompts[start:])/batch_size)\n",
"\n",
"\n",
"with open(\"temp_output.json\", \"a\") as output_file:\n",
"\n",
" for chunk_prompts in tqdm(divide_chunks(prompts[start:], batch_size), total=total):\n",
" length = [len(prompt.split(\"\\n\")[-2].split(\" \")[1:]) for prompt in chunk_prompts]\n",
" max_len = max(length) * length_multiplier\n",
"\n",
" outputs = batch_translate(prompts = chunk_prompts,\n",
" max_tokens = max_len,\n",
" model = model,\n",
" temperature=temperature,\n",
" top_p = top_p)\n",
" batch_translations = [output.choices[0].message.content.strip() for output in outputs]\n",
" translations += batch_translations\n",
"\n",
" output_translations = [{\"translation\": translation.strip()} for translation in batch_translations]\n",
" output_translations = \"\\n\".join([json.dumps(translation, ensure_ascii=False) for translation in output_translations])\n",
" # Write raw translations to a JSON file (without handling over-generation)\n",
" output_file.write(output_translations + \"\\n\")\n",
" output_file.flush()\n",
"\n",
" sleep(10)\n",
"\n",
"\n",
"# Report stats\n",
"print(\"Translations:\", len(translations), end=\"\\n\\n\")\n",
"print(\"• Last Translation:\")\n",
"print(\"Prompt Tokens:\", outputs[-1].usage.prompt_tokens)\n",
"print(\"Completion Tokens:\", outputs[-1].usage.completion_tokens)\n",
"print(\"Total Tokens:\", outputs[-1].usage.total_tokens, end=\"\\n\\n\")\n",
"print(prompts[-1], end=\" \")\n",
"print(translations[-1], sep=\"\\n\")"
],
"metadata": {
"id": "Y2tbPOetSByU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"len(translations)"
],
"metadata": {
"id": "h_gfH2BWOxi8"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Print the first 5 translations\n",
"print(*translations[:5], sep=\"\\n\")"
],
"metadata": {
"id": "rg4K3jvGd7SY"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Save translations"
],
"metadata": {
"id": "R3pP-eLLpo4a"
}
},
{
"cell_type": "code",
"source": [
"translations_file_name = \"all-filtered.esen.ms-multi-12.online.test.translated-ChatGPT-gpt-3.5-turbo-zero-shot.en\"\n",
"# translations_file_name = \"all-filtered.esen.ms-multi-12.online.test.translated-ChatGPT-gpt-3.5-turbo-one-shot.en\""
],
"metadata": {
"id": "BgHChX4sdY0u"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!pip3 install nltk -q\n",
"\n",
"import nltk\n",
"nltk.download(\"punkt\")"
],
"metadata": {
"id": "YBOYyUxh9A2M"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Save translations to a file\n",
"# This code also handles over-generation\n",
"\n",
"from nltk import sent_tokenize, word_tokenize\n",
"import os\n",
"\n",
"# ✳️ Where to save the translations\n",
"# It is better to connect Google Drive, and change 'directory'\n",
"directory = \"\"\n",
"output_file_name = translations_file_name\n",
"output_path = os.path.join(directory, output_file_name)\n",
"\n",
"with open(output_path, \"w+\") as translated_file:\n",
" for source, translation in zip(source_sentences, translations):\n",
" translation = translation.strip()\n",
" if \"\\n\" in translation:\n",
" translation = translation.split(\"\\n\")[0]\n",
" translated_file.write(translation.strip() + \"\\n\")\n",
" elif len(sent_tokenize(translation)) > len(sent_tokenize(source)) and len(word_tokenize(translation)) > len(word_tokenize(source))*2:\n",
" translation = sent_tokenize(translation)[0]\n",
" translated_file.write(translation.strip() + \"\\n\")\n",
" else:\n",
" translated_file.write(translation.strip() + \"\\n\")\n",
"\n",
"print(\"Translation file saved at:\", output_path)"
],
"metadata": {
"id": "MwQk9YgIpn4v"
},
"execution_count": null,
"outputs": []
}
]
}
\ No newline at end of file
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 在 GPU 上使用DeepSpeed-Inference加速GPT的Inference过程\n",
"\n",
"使用Hugging Face Transformers和DeepSpeed-Inference来优化GPT-2/GPT-J的推理性能。\n",
"\n",
"- 设置开发环境\n",
"- 加载初始的GPT-J模型并设定基线\n",
"- 使用DeepSpeed的InferenceEngine优化GPT-J以适配GPU\n",
"- 评估性能和速度\n",
"\n",
"注意!本教程是在包含NVIDIA T4的g4dn.2xlarge AWS EC2实例上创建和运行的,请自行学习如何使用云服务器 或 云服务器大模型服务提供商。\n",
"\n",
"---\n",
"\n",
"## 什么是Deepspeed Inference\n",
"\n",
"- DeepSpeed-Inference是DeepSpeed框架的扩展,用于提升推理工作负载。\n",
"- DeepSpeed Inference结合了张量并行(tensor parallelism)、流水线并行(pipeline-parallelism)等模型并行技术,并使用了自定义优化的CUDA内核。\n",
"- DeepSpeed为使用DeepSpeed、Megatron和HuggingFace训练的兼容Transformer模型提供了无缝的推理模式。\n",
"- 举例来说,DeepSpeed-Inference集成了模型并行技术,允许您在多GPU上运行大型语言模型(LLM)的推理,如具有1760亿参数的BLOOM。\n"
]
},
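{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal multi-GPU sketch (illustrative only, not part of the original tutorial): the script would be started with the `deepspeed` launcher, and `mp_size` set to the number of GPUs. The file name `sketch.py` is hypothetical; the `init_inference` arguments follow the same DeepSpeed 0.7.x-era API used later in this notebook.\n",
"\n",
"```python\n",
"# sketch.py (launch with: deepspeed --num_gpus 2 sketch.py)\n",
"import torch\n",
"import deepspeed\n",
"from transformers import AutoModelForCausalLM\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\"EleutherAI/gpt-j-6B\", torch_dtype=torch.float16)\n",
"ds_model = deepspeed.init_inference(\n",
"    model=model,\n",
"    mp_size=2,                       # shard the model across 2 GPUs (tensor parallelism)\n",
"    dtype=torch.float16,\n",
"    replace_with_kernel_inject=True  # use DeepSpeed's optimized CUDA kernels\n",
")\n",
"```"
]
},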
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. 设置开发环境\n",
"\n",
"- 安装DeepSpeed,以及PyTorch、Transformers和其他一些库。\n",
"- 运行以下代码单元格将安装所有必需的包。\n",
"\n",
"_注意:需要一台带有GPU并安装了兼容的CUDA的机器。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install torch==1.11.0 torchvision==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 --upgrade -q \n",
"# !pip install deepspeed==0.7.2 --upgrade -q \n",
"!pip install git+https://github.com/microsoft/DeepSpeed.git@ds-inference/support-large-token-length --upgrade\n",
"!pip install transformers[sentencepiece]==4.21.2 accelerate --upgrade -q \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 在开始之前,我们需要确认所有的packages都正常安装。"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import re # 导入正则表达式库re,用于字符串匹配\n",
"import torch # 导入PyTorch库\n",
"\n",
"# 检查DeepSpeed的安装情况\n",
"# 使用deepspeed.env_report命令输出当前DeepSpeed环境的信息\n",
"report = !python3 -m deepspeed.env_report\n",
"\n",
"# 使用正则表达式编译一个模式,用于匹配输出中的'ninja'状态是否为'OKAY'\n",
"r = re.compile('.*ninja.*OKAY.*')\n",
"\n",
"# 断言判断,如果report中没有匹配到'ninja' OKAY状态,则抛出异常提示DeepSpeed Inference未正确安装\n",
"assert any(r.match(line) for line in report) == True, \"DeepSpeed Inference not correctly installed\"\n",
"\n",
"# 检查CUDA和PyTorch版本\n",
"# 从torch.__version__获取torch和cuda版本信息\n",
"torch_version, cuda_version = torch.__version__.split(\"+\")\n",
"\n",
"# 只保留torch版本的前两位,比如从'1.9.1'中提取'1.9'\n",
"torch_version = \".\".join(torch_version.split(\".\")[:2])\n",
"\n",
"# 格式化CUDA版本为标准显示格式,例如'cu101'变为'10.1'\n",
"cuda_version = f\"{cuda_version[2:4]}.{cuda_version[4:]}\"\n",
"\n",
"# 正则表达式用于匹配DeepSpeed报告中的torch版本信息\n",
"r = re.compile(f'.*torch.*{torch_version}.*')\n",
"# 如果版本不匹配,抛出错误提示\n",
"assert any(r.match(line) for line in report) == True, \"Wrong Torch version\"\n",
"\n",
"# 正则表达式用于匹配DeepSpeed报告中的cuda版本信息\n",
"r = re.compile(f'.*cuda.*{cuda_version}.*')\n",
"# 如果版本不匹配,抛出错误提示\n",
"assert any(r.match(line) for line in report) == True, \"Wrong Cuda version\"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. 加载原生 GPT-J 模型并设置baseline\n",
"\n",
"- 在设置好环境后,为模型创建一个基线\n",
"- 此处使用的是EleutherAI/gpt-j-6B,一个由EleutherAI训练的GPT-J 6B模型。该模型在一个大规模的精选数据集The Pile上进行了训练\n",
"- 训练过程中,使用TPU v3-256 pod在383,500步内处理了4020亿个tokens\n",
"- 它作为一个自回归语言模型(autoregressive language model)进行训练\n",
"- 使用交叉熵损失(cross-entropy loss)来最大化正确预测下一个token的概率\n",
"\n",
"- 使用transformers加载模型并运行推理,创建基线。\n",
"\n",
"_注意:这里创建了一个单独的仓库,其中包含分片的fp16权重,以便通过使用device_map功能自动将分片的检查点加载到GPU上,从而更容易在较小的CPU上加载模型_"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"model is loaded on device cuda\n"
]
}
],
"source": [
"import torch # 导入PyTorch库\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline # 导入Transformers库中的必要模块\n",
"\n",
"# Hugging Face模型库中的模型仓库ID\n",
"model_id = \"philschmid/gpt-j-6B-fp16-sharded\"\n",
"\n",
"# 加载模型和分词器\n",
"# 使用AutoTokenizer加载分词器,该分词器与模型ID匹配\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"\n",
"# 使用AutoModelForCausalLM加载因果语言模型\n",
"# 这里设置torch_dtype为float16以减少内存使用,同时使用device_map=\"auto\"自动将所有模型分片放置到GPU上\n",
"model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map=\"auto\")\n",
"\n",
"# 输出确认模型已加载到设备上\n",
"print(f\"model is loaded on device {model.device.type}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"input payload: \n",
" \n",
"Hello my name is Philipp. I am getting in touch with you because i didn't get a response from you. What do I need to do to get my new card which I have requested 2 weeks ago? Please help me and answer this email in the next 7 days. Best regards and have a nice weekend but it\n",
"prediction: \n",
" \n",
" 's Friday evening for the British and you can feel that coming in on top of a Friday, please try to spend a quiet time tonight. Thankyou, Philipp\n",
"\n",
"Annette\n",
"\n",
"Thank you for your reply to my last email. Regarding your issue with your new credit card please forward your request by email to \"customer.service@lodging.com\" In order for this to happen the email you send will need to include your full name, card number and the account number that your new card is linked to. Your credit card account number should be at the top of the email to avoid any misinterpretation of the request\n"
]
}
],
"source": [
"# 定义输入文本(payload),这是需要模型生成回复的内容\n",
"payload = (\"Hello my name is Philipp. I am getting in touch with you because i didn't get a response from you. \"\n",
" \"What do I need to do to get my new card which I have requested 2 weeks ago? Please help me and answer this \"\n",
" \"email in the next 7 days. Best regards and have a nice weekend but it\")\n",
"\n",
"# 使用分词器将输入文本转换为模型可接受的输入ID,并将其放置到与模型相同的设备上(如GPU)\n",
"input_ids = tokenizer(payload, return_tensors=\"pt\").input_ids.to(model.device)\n",
"\n",
"# 打印输入的payload内容\n",
"print(f\"input payload: \\n\\n{payload}\")\n",
"\n",
"# 使用加载的模型进行文本生成推理\n",
"# 参数解释:\n",
"# - `do_sample=True`: 表示使用采样方法生成文本(而不是贪婪搜索)\n",
"# - `num_beams=1`: 表示使用单束搜索(即不进行束搜索优化)\n",
"# - `min_length=128`: 生成的最小长度为128个token\n",
"# - `max_new_tokens=128`: 最大生成128个新的token\n",
"logits = model.generate(input_ids, do_sample=True, num_beams=1, min_length=128, max_new_tokens=128)\n",
"\n",
"# 打印模型生成的预测输出\n",
"# 使用`tokenizer.decode`方法将生成的token ID转回人类可读的文本\n",
"# `logits[0].tolist()[len(input_ids[0]):]`用于提取生成的部分,而不是包括输入部分\n",
"print(f\"prediction: \\n\\n{tokenizer.decode(logits[0].tolist()[len(input_ids[0]):])}\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"data": {
"text/plain": [
"\"My name is Philipp and I'm from Germany.\\nI have been a collector of music since I am a small child. I own about 4500 vinyls and 10k CDs but my collection is by no means the most important thing in my life. I am married with two kids.\\nI got into the IT business many years ago, worked in the game industry where I learned a lot about programming, but now I live on the other side of the fence. We sell our house and are\""
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 测试模型生成能力\n",
"# 定义一个简单的输入示例\n",
"example = \"My name is Philipp and I\"\n",
"\n",
"# 使用分词器将输入示例转换为模型可接受的输入ID,并将其转移到模型所在的设备(GPU)\n",
"input_ids = tokenizer(example, return_tensors=\"pt\").input_ids.to(model.device)\n",
"\n",
"# 使用加载的模型生成文本\n",
"# 参数解释:\n",
"# - `do_sample=True`: 表示使用采样策略来生成输出,增加多样性\n",
"# - `max_length=100`: 生成的文本最大长度为100个token\n",
"logits = model.generate(input_ids, do_sample=True, max_length=100)\n",
"\n",
"# 将生成的token ID转换为人类可读的文本\n",
"output_text = tokenizer.decode(logits[0].tolist())\n",
"\n",
"# 打印生成的文本结果\n",
"print(output_text)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 使用measure_latency函数来创建延迟基线,该函数通过一个简单的Python循环来运行推理,并计算模型的平均延迟(avg)、中位延迟(mean)和95百分位延迟(p95)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from time import perf_counter # 从time库中导入perf_counter,用于高精度计时\n",
"import numpy as np # 导入numpy库,用于计算延迟的统计量\n",
"import transformers # 导入transformers库\n",
"\n",
"# 隐藏生成时的警告信息\n",
"transformers.logging.set_verbosity_error()\n",
"\n",
"def measure_latency(model, tokenizer, payload, generation_args={}, device=model.device):\n",
" \"\"\"\n",
" 测量模型的推理延迟。\n",
"\n",
" 参数:\n",
" - model: 已加载的语言模型\n",
" - tokenizer: 模型对应的分词器\n",
" - payload: 要进行推理的输入文本\n",
" - generation_args: 生成参数的字典(默认空)\n",
" - device: 运行推理的设备(默认为模型的设备)\n",
"\n",
" 返回值:\n",
" - 延迟统计信息的字符串格式\n",
" - p95延迟值\n",
" \"\"\"\n",
" # 使用分词器将输入文本转换为模型可接受的输入ID,并将其放置到指定设备上\n",
" input_ids = tokenizer(payload, return_tensors=\"pt\").input_ids.to(device)\n",
" latencies = [] # 初始化一个列表来存储每次推理的延迟时间\n",
"\n",
" # 预热模型(warm up)\n",
" for _ in range(2):\n",
" _ = model.generate(input_ids, **generation_args)\n",
"\n",
" # 测量推理延迟\n",
" for _ in range(10):\n",
" start_time = perf_counter() # 记录推理开始时间\n",
" _ = model.generate(input_ids, **generation_args) # 运行模型推理\n",
" latency = perf_counter() - start_time # 计算单次推理的延迟\n",
" latencies.append(latency) # 将延迟添加到列表中\n",
"\n",
" # 计算延迟的统计量\n",
" time_avg_ms = 1000 * np.mean(latencies) # 计算平均延迟(毫秒)\n",
" time_std_ms = 1000 * np.std(latencies) # 计算延迟的标准差(毫秒)\n",
" time_p95_ms = 1000 * np.percentile(latencies, 95) # 计算95百分位延迟(毫秒)\n",
"\n",
" # 返回格式化的延迟统计信息和95百分位延迟值\n",
" return f\"P95 latency (ms) - {time_p95_ms:.2f}; Average latency (ms) - {time_avg_ms:.2f} ± {time_std_ms:.2f};\", time_p95_ms\n",
"\n",
"# 示例使用\n",
"payload = \"My name is Philipp and I\"\n",
"generation_args = {'do_sample': True, 'max_length': 100}\n",
"\n",
"# 调用measure_latency函数并输出结果\n",
"latency_info, p95_latency = measure_latency(model, tokenizer, payload, generation_args)\n",
"print(latency_info)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"解码策略与生成设置\n",
"使用贪婪搜索(greedy search)作为解码策略,并将生成128个新的token,输入的长度也为128个token。\n",
"\n",
"贪婪搜索(Greedy Search):\n",
"\n",
"贪婪搜索是一种常见的解码策略,在生成文本时,每一步选择概率最高的下一个token。虽然这种方法能够快速生成文本,但容易陷入局部最优解,可能导致生成的文本缺乏多样性。\n",
"与其他解码策略(如束搜索、采样方法)相比,贪婪搜索通常更快,但可能不如其他方法生成的文本质量高,尤其是在需要生成更复杂或更具创意的内容时。\n",
"生成设置:\n",
"在推理过程中,将模型的输入设置为128个token,这意味着输入文本将被分割为128个token的序列。\n",
"然后,模型将基于输入生成新的文本,生成的输出长度也为128个token。目的是确保生成的输出具有足够的上下文,以便进行有效的性能评估。\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Payload sequence length is: 128\n",
"Vanilla model: P95 latency (ms) - 8985.898722249989; Average latency (ms) - 8955.07 +\\- 24.34;\n"
]
}
],
"source": [
"# 扩展输入文本(payload),通过重复两次相同的内容来增加输入序列的长度\n",
"payload = (\n",
" \"Hello my name is Philipp. I am getting in touch with you because i didn't get a response from you. \"\n",
" \"What do I need to do to get my new card which I have requested 2 weeks ago? Please help me and answer \"\n",
" \"this email in the next 7 days. Best regards and have a nice weekend but it\"\n",
") * 2 # 将输入文本重复两次以扩展长度\n",
"\n",
"# 打印扩展后的输入序列长度\n",
"print(f'Payload sequence length is: {len(tokenizer(payload)[\"input_ids\"])}')\n",
"\n",
"# 生成的参数设置\n",
"generation_args = dict(\n",
" do_sample=False, # 不使用采样,使用贪婪搜索\n",
" num_beams=1, # 使用单束搜索(不进行束搜索优化)\n",
" min_length=128, # 生成的最小长度为128个token\n",
" max_new_tokens=128 # 最大生成128个新的token\n",
")\n",
"\n",
"# 使用测量延迟函数来评估未优化模型(Vanilla model)的延迟\n",
"vanilla_results = measure_latency(model, tokenizer, payload, generation_args)\n",
"\n",
"# 打印未优化模型的延迟结果\n",
"print(f\"Vanilla model: {vanilla_results[0]}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"模型在生成128个token的情况下,达到了8.9秒的推理延迟,相当于每个token的生成时间为69毫秒"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. 使用 DeepSpeeds 的 `InferenceEngine` 优化 GPT-J\n",
"\n",
"- 接下来也是最重要的一步是优化模型以在GPU上进行推理。\n",
"- 使用DeepSpeed的InferenceEngine来实现。\n",
"- InferenceEngine通过init_inference方法进行初始化。init_inference方法至少需要以下几个参数:\n",
"\n",
"- model: 需要优化的模型。\n",
"- mp_size: 使用的GPU数量(模型并行的数量)。\n",
"- dtype: 使用的数据类型(如float16)。\n",
"- replace_with_kernel_inject: 是否注入自定义CUDA内核。\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch # 导入PyTorch库\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM # 导入Transformers库中的模型和分词器加载方法\n",
"import deepspeed # 导入DeepSpeed库\n",
"\n",
"# Hugging Face模型库中的模型仓库ID\n",
"model_id = \"philschmid/gpt-j-6B-fp16-sharded\"\n",
"\n",
"# 加载模型和分词器\n",
"# 使用AutoTokenizer加载分词器\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"\n",
"# 使用AutoModelForCausalLM加载因果语言模型,设置权重数据类型为float16\n",
"model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)\n",
"\n",
"# 初始化DeepSpeed推理引擎\n",
"# 参数解释:\n",
"# - model: 要优化的Transformer模型实例\n",
"# - mp_size: 使用的GPU数量(这里设置为1)\n",
"# - dtype: 模型权重的数据类型(这里设置为float16)\n",
"# - replace_method: 设置为\"auto\",让DeepSpeed自动识别需要替换的层\n",
"# - replace_with_kernel_inject: 设置为True,使用DeepSpeed的内核注入器替换默认的CUDA内核\n",
"ds_model = deepspeed.init_inference(\n",
" model=model, # 需要优化的模型\n",
" mp_size=1, # 使用1个GPU\n",
" dtype=torch.float16, # 使用半精度浮点数(fp16)以减少内存占用\n",
" replace_method=\"auto\", # 让DeepSpeed自动识别和替换需要优化的层\n",
" replace_with_kernel_inject=True # 使用DeepSpeed的内核注入器进行优化\n",
")\n",
"\n",
"# 打印确认模型已加载到哪个设备上\n",
"print(f\"model is loaded on device {ds_model.module.device}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"现在检查模型的计算图,验证原始的GPTJLayer已经被HFGPTJLayer替换,而HFGPTJLayer包含了DeepSpeedTransformerInference模块。\n",
"\n",
"```python\n",
"InferenceEngine(\n",
" (module): GPTJForCausalLM( # GPT-J的因果语言模型模块\n",
" (transformer): GPTJModel( # GPT-J的Transformer模型\n",
" (wte): Embedding(50400, 4096) # 词嵌入层(Embedding layer)\n",
" (drop): Dropout(p=0.0, inplace=False) # Dropout层,丢弃概率为0.0\n",
" (h): ModuleList( # 模型的主体部分是一个模块列表(包括多个Transformer层)\n",
" (0): DeepSpeedTransformerInference( # 使用DeepSpeed优化的Transformer推理模块\n",
" (attention): DeepSpeedSelfAttention() # 使用DeepSpeed优化的自注意力层\n",
" (mlp): DeepSpeedMLP() # 使用DeepSpeed优化的多层感知机(MLP)层\n",
" )\n",
"\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from deepspeed.ops.transformer.inference import DeepSpeedTransformerInference # 导入DeepSpeed的推理优化模块\n",
"\n",
"# 断言检查:验证模型的第一个Transformer层是否是DeepSpeed优化的Transformer推理模块\n",
"assert isinstance(ds_model.module.transformer.h[0], DeepSpeedTransformerInference) == True, \"Model not successfully initialized\"\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'My name is Philipp and I live in Freiburg in Germany and I have a project called Cenapen. After three months in development already it is finally finished – and it is a Linux based device / operating system on an ARM Cortex A9 processor on a Raspberry Pi.\\n\\nAt the moment it offers the possibility to store data locally, it can retrieve data from a local, networked or web based Sqlite database (I’m writing this tutorial while I’'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 测试模型生成能力\n",
"# 定义一个简单的输入示例\n",
"example = \"My name is Philipp and I\"\n",
"\n",
"# 使用分词器将输入示例转换为模型可接受的输入ID,并将其转移到模型所在的设备(GPU)\n",
"input_ids = tokenizer(example, return_tensors=\"pt\").input_ids.to(model.device)\n",
"\n",
"# 使用DeepSpeed优化后的模型生成文本\n",
"# 参数解释:\n",
"# - `do_sample=True`: 表示使用采样策略来生成输出,增加多样性\n",
"# - `max_length=100`: 生成的文本最大长度为100个token\n",
"logits = ds_model.generate(input_ids, do_sample=True, max_length=100)\n",
"\n",
"# 将生成的token ID转换为人类可读的文本\n",
"output_text = tokenizer.decode(logits[0].tolist())\n",
"\n",
"# 打印生成的文本结果\n",
"print(output_text)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. 评价效率和速度\n",
"\n",
"作为最后一步,需要详细分析优化后的模型性能。应用优化技术(如图优化和混合精度)不仅会影响性能(延迟),还可能对模型的准确性产生影响。因此,加速模型往往伴随着一定的权衡。\n",
"\n",
"- 性能与准确性的权衡:\n",
"\n",
"- 通过图优化(graph optimizations)和混合精度(mixed-precision)等技术,可以显著提高模型推理的速度和降低延迟。\n",
"但是,这些技术也可能导致模型的准确性有所下降。对于实际应用,优化的目标需要在性能和准确性之间找到合适的平衡点。\n",
"\n",
"- 测试优化后的模型性能:\n",
"\n",
"- 使用和原始模型(vanilla model)相同的生成参数(generation_args)来测试优化后的模型性能。这确保了测试的公平性,可以直接对比优化前后的性能差异。"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Payload sequence length is: 128\n",
"DeepSpeed model: P95 latency (ms) - 6577.044982599967; Average latency (ms) - 6569.11 +\\- 6.57;\n"
]
}
],
"source": [
"# 扩展输入文本(payload),通过重复两次相同的内容来增加输入序列的长度\n",
"payload = (\n",
" \"Hello my name is Philipp. I am getting in touch with you because i didn't get a response from you. \"\n",
" \"What do I need to do to get my new card which I have requested 2 weeks ago? Please help me and answer \"\n",
" \"this email in the next 7 days. Best regards and have a nice weekend but it\"\n",
" * 2 # 将输入文本重复两次以扩展长度\n",
")\n",
"\n",
"# 打印扩展后的输入序列长度\n",
"print(f'Payload sequence length is: {len(tokenizer(payload)[\"input_ids\"])}')\n",
"\n",
"# 生成参数设置\n",
"generation_args = dict(\n",
" do_sample=False, # 不使用采样,使用贪婪搜索\n",
" num_beams=1, # 使用单束搜索(不进行束搜索优化)\n",
" min_length=128, # 生成的最小长度为128个token\n",
" max_new_tokens=128 # 最大生成128个新的token\n",
")\n",
"\n",
"# 使用之前定义的measure_latency函数来测试DeepSpeed优化后的模型的推理延迟\n",
"# 参数包括:DeepSpeed优化后的模型、分词器、扩展后的输入文本、生成参数以及模型所在的设备\n",
"ds_results = measure_latency(ds_model, tokenizer, payload, generation_args, ds_model.module.device)\n",
"\n",
"# 打印DeepSpeed优化后的模型的延迟结果\n",
"print(f\"DeepSpeed model: {ds_results[0]}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input payload: \n",
" \n",
"Hello my name is Philipp. I am getting in touch with you because i didn't get a response from you. What do I need to do to get my new card which I have requested 2 weeks ago? Please help me and answer this email in the next 7 days. Best regards and have a nice weekend but it\n",
"prediction: \n",
" \n",
" 's not over yet.\n",
"\n",
"I am getting in touch with you because i didn't get a response from you. What do I need to do to get my new card which I have requested 2 weeks ago? Please help me and answer this email in the next 7 days. Best regards and have a nice weekend but\n"
]
}
],
"source": [
"# 定义输入文本(payload)\n",
"payload = (\n",
" \"Hello my name is Philipp. I am getting in touch with you because i didn't get a response from you. \"\n",
" \"What do I need to do to get my new card which I have requested 2 weeks ago? Please help me and answer \"\n",
" \"this email in the next 7 days. Best regards and have a nice weekend but it\"\n",
")\n",
"\n",
"# 使用分词器将输入文本转换为模型可接受的输入ID,并将其放置到与模型相同的设备(GPU)上\n",
"input_ids = tokenizer(payload, return_tensors=\"pt\").input_ids.to(model.device)\n",
"\n",
"# 打印输入的payload内容\n",
"print(f\"input payload: \\n\\n{payload}\")\n",
"\n",
"# 使用DeepSpeed优化后的模型生成文本\n",
"# 参数解释:\n",
"# - `do_sample=False`: 不使用采样,使用确定性的方法生成文本\n",
"# - `num_beams=2`: 使用束宽为2的束搜索(Beam Search)来生成输出\n",
"# - `min_length=64`: 生成的最小长度为64个token\n",
"# - `max_new_tokens=64`: 最大生成64个新的token\n",
"logits = ds_model.generate(input_ids, do_sample=False, num_beams=2, min_length=64, max_new_tokens=64)\n",
"\n",
"# 将生成的token ID转换为人类可读的文本\n",
"output_text = tokenizer.decode(logits[0].tolist()[len(input_ids[0]):])\n",
"\n",
"# 打印模型生成的预测输出\n",
"print(f\"prediction: \\n\\n{output_text}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"优化后的DeepSpeed模型在生成128个token时达到了6.5秒的推理延迟,相当于每个token的生成时间为50毫秒。\n",
"\n",
"性能提升分析\n",
"性能改进结果:\n",
"\n",
"优化前的GPT-J-6B模型生成128个token的延迟为8.9秒(即69毫秒/token)。\n",
"优化后的模型生成128个token的延迟降低到6.5秒(即50毫秒/token)。\n",
"提升幅度计算:\n",
"\n",
"优化前的每个token延迟为69毫秒,优化后为50毫秒。\n",
"计算性能提升倍数:\n",
"提升倍数1.38\n",
"结果表明,经过优化后,模型的推理速度提升了1.38倍。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.13 ('dev')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"vscode": {
"interpreter": {
"hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
-- "a/2-\345\272\224\347\224\250/2.17-\350\257\255\351\237\263\344\272\244\344\272\222/.gitkeep"
++ /dev/null
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOxhMeI0i4OtD359azGPZtt"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":4,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"dP70MpFhWahb","executionInfo":{"status":"ok","timestamp":1708851088669,"user_tz":-480,"elapsed":13734,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"83526d38-ed46-4f4b-a65e-0087c41098ff"},"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: openai in /usr/local/lib/python3.10/dist-packages (1.12.0)\n","Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from openai) (3.7.1)\n","Requirement already satisfied: distro<2,>=1.7.0 in /usr/lib/python3/dist-packages (from openai) (1.7.0)\n","Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from openai) (0.27.0)\n","Requirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from openai) (2.6.1)\n","Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from openai) (1.3.0)\n","Requirement already satisfied: tqdm>4 in /usr/local/lib/python3.10/dist-packages (from openai) (4.66.2)\n","Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from openai) (4.9.0)\n","Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (3.6)\n","Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (1.2.0)\n","Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->openai) (2024.2.2)\n","Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->openai) (1.0.4)\n","Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->openai) (0.14.0)\n","Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (0.6.0)\n","Requirement already satisfied: pydantic-core==2.16.2 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (2.16.2)\n","Collecting pydub\n"," Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)\n","Installing collected packages: pydub\n","Successfully installed pydub-0.25.1\n"]}],"source":["!pip install openai\n","!pip install pydub"]},{"cell_type":"code","source":["from openai import OpenAI\n","import os\n","import urllib\n","from IPython.display import Audio\n","from pathlib import Path\n","from pydub import AudioSegment\n","import ssl\n","\n","from google.colab import userdata\n","key = userdata.get('OpenAI-Key')\n","client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", key))"],"metadata":{"id":"ekSFYNH-dNxs","executionInfo":{"status":"ok","timestamp":1708851096508,"user_tz":-480,"elapsed":4015,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":5,"outputs":[]},{"cell_type":"code","source":["# set download paths\n","earnings_call_remote_filepath = \"https://cdn.openai.com/API/examples/data/EarningsCall.wav\"\n","\n","# set local save locations\n","earnings_call_filepath = \"EarningsCall.wav\"\n","\n","# download example audio files and save 
locally\n","ssl._create_default_https_context = ssl._create_unverified_context\n","urllib.request.urlretrieve(earnings_call_remote_filepath, earnings_call_filepath)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"nvsx1S2wdbY9","executionInfo":{"status":"ok","timestamp":1708851102141,"user_tz":-480,"elapsed":1888,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"dab8851c-0edc-4e73-a63a-904b64e79013"},"execution_count":6,"outputs":[{"output_type":"execute_result","data":{"text/plain":["('EarningsCall.wav', <http.client.HTTPMessage at 0x7dc9986f70d0>)"]},"metadata":{},"execution_count":6}]},{"cell_type":"code","source":["# Function to detect leading silence\n","# Returns the number of milliseconds until the first sound (chunk averaging more than X decibels)\n","def milliseconds_until_sound(sound, silence_threshold_in_decibels=-20.0, chunk_size=10):\n"," trim_ms = 0 # ms\n","\n"," assert chunk_size > 0 # to avoid infinite loop\n"," while sound[trim_ms:trim_ms+chunk_size].dBFS < silence_threshold_in_decibels and trim_ms < len(sound):\n"," trim_ms += chunk_size\n","\n"," return trim_ms"],"metadata":{"id":"qBS1pqzfdcSz","executionInfo":{"status":"ok","timestamp":1708851279522,"user_tz":-480,"elapsed":3,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["def trim_start(filepath):\n"," path = Path(filepath)\n"," directory = path.parent\n"," filename = path.name\n"," audio = AudioSegment.from_file(filepath, format=\"wav\")\n"," start_trim = milliseconds_until_sound(audio)\n"," trimmed = audio[start_trim:]\n"," new_filename = directory / f\"trimmed_{filename}\"\n"," trimmed.export(new_filename, format=\"wav\")\n"," return trimmed, new_filename"],"metadata":{"id":"0hIOI5YseTNg","executionInfo":{"status":"ok","timestamp":1708851290200,"user_tz":-480,"elapsed":342,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":8,"outputs":[]},{"cell_type":"code","source":["def transcribe_audio(file,output_dir):\n"," audio_path = os.path.join(output_dir, file)\n"," with open(audio_path, 'rb') as audio_data:\n"," transcription = client.audio.transcriptions.create(\n"," model=\"whisper-1\", file=audio_data)\n"," return transcription.text"],"metadata":{"id":"7Fd5gvnxeVvw","executionInfo":{"status":"ok","timestamp":1708851297581,"user_tz":-480,"elapsed":487,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":9,"outputs":[]},{"cell_type":"code","source":["# Define function to remove non-ascii characters\n","def remove_non_ascii(text):\n"," return ''.join(i for i in text if ord(i)<128)"],"metadata":{"id":"59HZSPGIeXe6","executionInfo":{"status":"ok","timestamp":1708851322999,"user_tz":-480,"elapsed":3,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":10,"outputs":[]},{"cell_type":"code","source":["# Define function to add punctuation\n","def punctuation_assistant(ascii_transcript):\n","\n"," system_prompt = \"\"\"You are a helpful assistant that adds punctuation to text.\n"," Preserve the original words and only insert necessary punctuation such as periods,\n"," commas, capialization, symbols like dollar sings or percentage signs, and formatting.\n"," Use only the context provided. 
If there is no context provided say, 'No context provided'\\n\"\"\"\n"," response = client.chat.completions.create(\n"," model=\"gpt-3.5-turbo\",\n"," temperature=0,\n"," messages=[\n"," {\n"," \"role\": \"system\",\n"," \"content\": system_prompt\n"," },\n"," {\n"," \"role\": \"user\",\n"," \"content\": ascii_transcript\n"," }\n"," ]\n"," )\n"," return response"],"metadata":{"id":"TmeaQdogedzk","executionInfo":{"status":"ok","timestamp":1708851345970,"user_tz":-480,"elapsed":503,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":11,"outputs":[]},{"cell_type":"code","source":["# Define function to fix product mispellings\n","def product_assistant(ascii_transcript):\n"," system_prompt = \"\"\"You are an intelligent assistant specializing in financial products;\n"," your task is to process transcripts of earnings calls, ensuring that all references to\n"," financial products and common financial terms are in the correct format. For each\n"," financial product or common term that is typically abbreviated as an acronym, the full term\n"," should be spelled out followed by the acronym in parentheses. For example, '401k' should be\n"," transformed to '401(k) retirement savings plan', 'HSA' should be transformed to 'Health Savings Account (HSA)'\n"," , 'ROA' should be transformed to 'Return on Assets (ROA)', 'VaR' should be transformed to 'Value at Risk (VaR)'\n",", and 'PB' should be transformed to 'Price to Book (PB) ratio'. Similarly, transform spoken numbers representing\n","financial products into their numeric representations, followed by the full name of the product in parentheses.\n","For instance, 'five two nine' to '529 (Education Savings Plan)' and 'four zero one k' to '401(k) (Retirement Savings Plan)'.\n"," However, be aware that some acronyms can have different meanings based on the context (e.g., 'LTV' can stand for\n","'Loan to Value' or 'Lifetime Value'). You will need to discern from the context which term is being referred to\n","and apply the appropriate transformation. In cases where numerical figures or metrics are spelled out but do not\n","represent specific financial products (like 'twenty three percent'), these should be left as is. Your role is to\n"," analyze and adjust financial product terminology in the text. 
Once you've done that, produce the adjusted\n"," transcript and a list of the words you've changed\"\"\"\n"," response = client.chat.completions.create(\n"," model=\"gpt-4\",\n"," temperature=0,\n"," messages=[\n"," {\n"," \"role\": \"system\",\n"," \"content\": system_prompt\n"," },\n"," {\n"," \"role\": \"user\",\n"," \"content\": ascii_transcript\n"," }\n"," ]\n"," )\n"," return response"],"metadata":{"id":"DiveGkRIejPY","executionInfo":{"status":"ok","timestamp":1708851409093,"user_tz":-480,"elapsed":358,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":12,"outputs":[]},{"cell_type":"code","source":["# Trim the start of the original audio file\n","trimmed_audio = trim_start(earnings_call_filepath)"],"metadata":{"id":"OGnRN5Hseyu4","executionInfo":{"status":"ok","timestamp":1708851422061,"user_tz":-480,"elapsed":2,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":13,"outputs":[]},{"cell_type":"code","source":["trimmed_audio, trimmed_filename = trim_start(earnings_call_filepath)"],"metadata":{"id":"vzt0dwHse2Ax","executionInfo":{"status":"ok","timestamp":1708851434593,"user_tz":-480,"elapsed":2,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":14,"outputs":[]},{"cell_type":"code","source":["# Segment audio\n","trimmed_audio = AudioSegment.from_wav(trimmed_filename) # Load the trimmed audio file\n","\n","one_minute = 1 * 60 * 1000 # Duration for each segment (in milliseconds)\n","\n","start_time = 0 # Start time for the first segment\n","\n","i = 0 # Index for naming the segmented files\n","\n","output_dir_trimmed = \"trimmed_earnings_directory\" # Output directory for the segmented files\n","\n","if not os.path.isdir(output_dir_trimmed): # Create the output directory if it does not exist\n"," os.makedirs(output_dir_trimmed)\n","\n","while start_time < len(trimmed_audio): # Loop over the trimmed audio file\n"," segment = trimmed_audio[start_time:start_time + one_minute] # Extract a segment\n"," segment.export(os.path.join(output_dir_trimmed, f\"trimmed_{i:02d}.wav\"), format=\"wav\") # Save the segment\n"," start_time += one_minute # Update the start time for the next segment\n"," i += 1 # Increment the index for naming the next file"],"metadata":{"id":"Lk7d7Gvme5Dk","executionInfo":{"status":"ok","timestamp":1708851447148,"user_tz":-480,"elapsed":3,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":15,"outputs":[]},{"cell_type":"code","source":["# Get list of trimmed and segmented audio files and sort them numerically\n","audio_files = sorted(\n"," (f for f in os.listdir(output_dir_trimmed) if f.endswith(\".wav\")),\n"," key=lambda f: int(''.join(filter(str.isdigit, f)))\n",")"],"metadata":{"id":"ouDjywBve8Fj","executionInfo":{"status":"ok","timestamp":1708851461397,"user_tz":-480,"elapsed":3,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":16,"outputs":[]},{"cell_type":"code","source":["# Use a loop to apply the transcribe function to all audio files\n","transcriptions = [transcribe_audio(file, output_dir_trimmed) for file in audio_files]"],"metadata":{"id":"TUpKUV2ke_gn","executionInfo":{"status":"ok","timestamp":1708851479212,"user_tz":-480,"elapsed":8685,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":17,"outputs":[]},{"cell_type":"code","source":["# Concatenate the transcriptions\n","full_transcript = ' 
'.join(transcriptions)"],"metadata":{"id":"f5IUwxVBfB2O","executionInfo":{"status":"ok","timestamp":1708851481684,"user_tz":-480,"elapsed":2,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":18,"outputs":[]},{"cell_type":"code","source":["print(full_transcript)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"A_A1CUa_fEk3","executionInfo":{"status":"ok","timestamp":1708851485687,"user_tz":-480,"elapsed":5,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"4e390258-f3d6-4c6b-9e3c-6de42ef5478f"},"execution_count":19,"outputs":[{"output_type":"stream","name":"stdout","text":["Good afternoon, everyone. And welcome to FinTech Plus Sync's second quarter 2023 earnings call. I'm John Doe, CEO of FinTech Plus. We've had a stellar Q2 with a revenue of 125 million, a 25% increase year over year. Our gross profit margin stands at a solid 58%, due in part to cost efficiencies gained from our scalable business model. Our EBITDA has surged to 37.5 million, translating to a remarkable 30% EBITDA margin. Our net income for the quarter rose to 16 million, which is a noteworthy increase from 10 million in Q2 2022. Our total addressable market has grown substantially thanks to the expansion of our high yield savings product line and the new RoboAdvisor platform. We've been diversifying our asset-backed securities portfolio, investing heavily in collateralized. debt obligations, and residential mortgage-backed securities. We've also invested $25 million in AAA rated corporate bonds, enhancing our risk adjusted returns. As for our balance sheet, total assets reached $1.5 billion with total liabilities at $900 million, leaving us with a solid equity base of $600 million. Our debt to equity ratio stands at 1.5, a healthy figure considering our expansionary phase. We continue to see substantial organic user growth, with customer acquisition cost dropping by 15% and lifetime value growing by 25%. Our LTVCAC ratio is at an impressive 3.5%. In terms of risk management, we have a value-at-risk model in place with a 99%... confidence level indicating that our maximum loss will not exceed 5 million in the next trading day. We've adopted a conservative approach to managing our leverage and have a healthy tier one capital ratio of 12.5%. Our forecast for the coming quarter is positive. We expect revenue to be around 135 million and 8% quarter over quarter growth driven primarily by our cutting edge blockchain solutions and AI driven predictive analytics. We're also excited about the upcoming IPO of our FinTech subsidiary, Pay Plus, which we expect to raise 200 million. Significantly bolstering our liquidity and paving the way for aggressive growth strategies. We thank our shareholders for their continued faith in us and we look forward to an even more successful Q3. 
Thank you so much.\n"]}]},{"cell_type":"code","source":["# Remove non-ascii characters from the transcript\n","ascii_transcript = remove_non_ascii(full_transcript)"],"metadata":{"id":"kjFntLyPfFh7","executionInfo":{"status":"ok","timestamp":1708851493166,"user_tz":-480,"elapsed":3,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":20,"outputs":[]},{"cell_type":"code","source":["\n","print(ascii_transcript)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"5m5mZTL0fHTz","executionInfo":{"status":"ok","timestamp":1708851498271,"user_tz":-480,"elapsed":2,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"d72c78e7-e6e8-4094-9f88-7df65b7665fc"},"execution_count":21,"outputs":[{"output_type":"stream","name":"stdout","text":["Good afternoon, everyone. And welcome to FinTech Plus Sync's second quarter 2023 earnings call. I'm John Doe, CEO of FinTech Plus. We've had a stellar Q2 with a revenue of 125 million, a 25% increase year over year. Our gross profit margin stands at a solid 58%, due in part to cost efficiencies gained from our scalable business model. Our EBITDA has surged to 37.5 million, translating to a remarkable 30% EBITDA margin. Our net income for the quarter rose to 16 million, which is a noteworthy increase from 10 million in Q2 2022. Our total addressable market has grown substantially thanks to the expansion of our high yield savings product line and the new RoboAdvisor platform. We've been diversifying our asset-backed securities portfolio, investing heavily in collateralized. debt obligations, and residential mortgage-backed securities. We've also invested $25 million in AAA rated corporate bonds, enhancing our risk adjusted returns. As for our balance sheet, total assets reached $1.5 billion with total liabilities at $900 million, leaving us with a solid equity base of $600 million. Our debt to equity ratio stands at 1.5, a healthy figure considering our expansionary phase. We continue to see substantial organic user growth, with customer acquisition cost dropping by 15% and lifetime value growing by 25%. Our LTVCAC ratio is at an impressive 3.5%. In terms of risk management, we have a value-at-risk model in place with a 99%... confidence level indicating that our maximum loss will not exceed 5 million in the next trading day. We've adopted a conservative approach to managing our leverage and have a healthy tier one capital ratio of 12.5%. Our forecast for the coming quarter is positive. We expect revenue to be around 135 million and 8% quarter over quarter growth driven primarily by our cutting edge blockchain solutions and AI driven predictive analytics. We're also excited about the upcoming IPO of our FinTech subsidiary, Pay Plus, which we expect to raise 200 million. Significantly bolstering our liquidity and paving the way for aggressive growth strategies. We thank our shareholders for their continued faith in us and we look forward to an even more successful Q3. 
Thank you so much.\n"]}]},{"cell_type":"code","source":["# Use punctuation assistant function\n","response = punctuation_assistant(ascii_transcript)"],"metadata":{"id":"c2XFW-VGfIn-","executionInfo":{"status":"ok","timestamp":1708851512225,"user_tz":-480,"elapsed":7046,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":22,"outputs":[]},{"cell_type":"code","source":["# Extract the punctuated transcript from the model's response\n","punctuated_transcript = response.choices[0].message.content"],"metadata":{"id":"prOqzj3ffKTu","executionInfo":{"status":"ok","timestamp":1708851512902,"user_tz":-480,"elapsed":2,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":23,"outputs":[]},{"cell_type":"code","source":["print(punctuated_transcript)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"DLbooVITfMK0","executionInfo":{"status":"ok","timestamp":1708851522144,"user_tz":-480,"elapsed":3,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"087b5ac5-56ba-46c3-d063-2b873d155361"},"execution_count":24,"outputs":[{"output_type":"stream","name":"stdout","text":["Good afternoon, everyone, and welcome to FinTech Plus Sync's second quarter 2023 earnings call. I'm John Doe, CEO of FinTech Plus. We've had a stellar Q2 with a revenue of 125 million, a 25% increase year over year. Our gross profit margin stands at a solid 58%, due in part to cost efficiencies gained from our scalable business model. Our EBITDA has surged to 37.5 million, translating to a remarkable 30% EBITDA margin. Our net income for the quarter rose to 16 million, which is a noteworthy increase from 10 million in Q2 2022. Our total addressable market has grown substantially thanks to the expansion of our high yield savings product line and the new RoboAdvisor platform. We've been diversifying our asset-backed securities portfolio, investing heavily in collateralized debt obligations, and residential mortgage-backed securities. We've also invested $25 million in AAA rated corporate bonds, enhancing our risk-adjusted returns. As for our balance sheet, total assets reached $1.5 billion with total liabilities at $900 million, leaving us with a solid equity base of $600 million. Our debt-to-equity ratio stands at 1.5, a healthy figure considering our expansionary phase. We continue to see substantial organic user growth, with customer acquisition cost dropping by 15% and lifetime value growing by 25%. Our LTVCAC ratio is at an impressive 3.5%. In terms of risk management, we have a value-at-risk model in place with a 99% confidence level indicating that our maximum loss will not exceed 5 million in the next trading day. We've adopted a conservative approach to managing our leverage and have a healthy tier one capital ratio of 12.5%. Our forecast for the coming quarter is positive. We expect revenue to be around 135 million and 8% quarter over quarter growth driven primarily by our cutting-edge blockchain solutions and AI-driven predictive analytics. We're also excited about the upcoming IPO of our FinTech subsidiary, Pay Plus, which we expect to raise 200 million, significantly bolstering our liquidity and paving the way for aggressive growth strategies. We thank our shareholders for their continued faith in us, and we look forward to an even more successful Q3. 
Thank you so much.\n"]}]},{"cell_type":"code","source":["# Use product assistant function\n","response = product_assistant(punctuated_transcript)"],"metadata":{"id":"gDXx2uFgfOcE","executionInfo":{"status":"ok","timestamp":1708851550633,"user_tz":-480,"elapsed":23541,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":25,"outputs":[]},{"cell_type":"code","source":["# Extract the final transcript from the model's response\n","final_transcript = response.choices[0].message.content"],"metadata":{"id":"gC2f0PJGfPk7","executionInfo":{"status":"ok","timestamp":1708851550633,"user_tz":-480,"elapsed":19,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":26,"outputs":[]},{"cell_type":"code","source":["print(final_transcript)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"CPnsDbxzfQ8j","executionInfo":{"status":"ok","timestamp":1708851550633,"user_tz":-480,"elapsed":18,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"a321937d-6557-4c33-d699-e5a7ecea76dd"},"execution_count":27,"outputs":[{"output_type":"stream","name":"stdout","text":["Good afternoon, everyone, and welcome to FinTech Plus Sync's second quarter 2023 earnings call. I'm John Doe, CEO of FinTech Plus. We've had a stellar second quarter (Q2) with a revenue of 125 million, a 25% increase year over year. Our gross profit margin stands at a solid 58%, due in part to cost efficiencies gained from our scalable business model. Our Earnings Before Interest, Taxes, Depreciation, and Amortization (EBITDA) has surged to 37.5 million, translating to a remarkable 30% EBITDA margin. Our net income for the quarter rose to 16 million, which is a noteworthy increase from 10 million in Q2 2022. Our total addressable market has grown substantially thanks to the expansion of our high yield savings product line and the new RoboAdvisor platform. We've been diversifying our asset-backed securities portfolio, investing heavily in Collateralized Debt Obligations (CDOs), and Residential Mortgage-Backed Securities (RMBS). We've also invested $25 million in AAA rated corporate bonds, enhancing our risk-adjusted returns. As for our balance sheet, total assets reached $1.5 billion with total liabilities at $900 million, leaving us with a solid equity base of $600 million. Our Debt-to-Equity (D/E) ratio stands at 1.5, a healthy figure considering our expansionary phase. We continue to see substantial organic user growth, with Customer Acquisition Cost (CAC) dropping by 15% and Lifetime Value (LTV) growing by 25%. Our LTV to CAC (LTVCAC) ratio is at an impressive 3.5%. In terms of risk management, we have a Value at Risk (VaR) model in place with a 99% confidence level indicating that our maximum loss will not exceed 5 million in the next trading day. We've adopted a conservative approach to managing our leverage and have a healthy Tier 1 Capital ratio of 12.5%. Our forecast for the coming quarter is positive. We expect revenue to be around 135 million and 8% quarter over quarter growth driven primarily by our cutting-edge blockchain solutions and AI-driven predictive analytics. We're also excited about the upcoming Initial Public Offering (IPO) of our FinTech subsidiary, Pay Plus, which we expect to raise 200 million, significantly bolstering our liquidity and paving the way for aggressive growth strategies. We thank our shareholders for their continued faith in us, and we look forward to an even more successful Q3. 
Thank you so much.\n","\n","Words Changed:\n","1. Q2 to second quarter (Q2)\n","2. EBITDA to Earnings Before Interest, Taxes, Depreciation, and Amortization (EBITDA)\n","3. CDOs to Collateralized Debt Obligations (CDOs)\n","4. RMBS to Residential Mortgage-Backed Securities (RMBS)\n","5. D/E to Debt-to-Equity (D/E)\n","6. CAC to Customer Acquisition Cost (CAC)\n","7. LTV to Lifetime Value (LTV)\n","8. LTVCAC to LTV to CAC (LTVCAC)\n","9. VaR to Value at Risk (VaR)\n","10. IPO to Initial Public Offering (IPO)\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"1n9UroAxfSNy"},"execution_count":null,"outputs":[]}]}
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOxhMeI0i4OtD359azGPZtt"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":4,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"dP70MpFhWahb","executionInfo":{"status":"ok","timestamp":1708851088669,"user_tz":-480,"elapsed":13734,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"83526d38-ed46-4f4b-a65e-0087c41098ff"},"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: openai in /usr/local/lib/python3.10/dist-packages (1.12.0)\n","Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from openai) (3.7.1)\n","Requirement already satisfied: distro<2,>=1.7.0 in /usr/lib/python3/dist-packages (from openai) (1.7.0)\n","Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from openai) (0.27.0)\n","Requirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from openai) (2.6.1)\n","Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from openai) (1.3.0)\n","Requirement already satisfied: tqdm>4 in /usr/local/lib/python3.10/dist-packages (from openai) (4.66.2)\n","Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from openai) (4.9.0)\n","Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (3.6)\n","Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (1.2.0)\n","Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->openai) (2024.2.2)\n","Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->openai) (1.0.4)\n","Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->openai) (0.14.0)\n","Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (0.6.0)\n","Requirement already satisfied: pydantic-core==2.16.2 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (2.16.2)\n","Collecting pydub\n"," Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)\n","Installing collected packages: pydub\n","Successfully installed pydub-0.25.1\n"]}],"source":["!pip install openai\n","!pip install pydub"]},{"cell_type":"code","source":["from openai import OpenAI\n","import os\n","import urllib\n","from IPython.display import Audio\n","from pathlib import Path\n","from pydub import AudioSegment\n","import ssl\n","\n","from google.colab import userdata\n","key = userdata.get('OpenAI-Key')\n","client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", key))"],"metadata":{"id":"ekSFYNH-dNxs","executionInfo":{"status":"ok","timestamp":1708851096508,"user_tz":-480,"elapsed":4015,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":5,"outputs":[]},{"cell_type":"code","source":["# set download paths\n","earnings_call_remote_filepath = \"https://cdn.openai.com/API/examples/data/EarningsCall.wav\"\n","\n","# set local save locations\n","earnings_call_filepath = \"EarningsCall.wav\"\n","\n","# download example audio files and save 
locally\n","ssl._create_default_https_context = ssl._create_unverified_context\n","urllib.request.urlretrieve(earnings_call_remote_filepath, earnings_call_filepath)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"nvsx1S2wdbY9","executionInfo":{"status":"ok","timestamp":1708851102141,"user_tz":-480,"elapsed":1888,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"dab8851c-0edc-4e73-a63a-904b64e79013"},"execution_count":6,"outputs":[{"output_type":"execute_result","data":{"text/plain":["('EarningsCall.wav', <http.client.HTTPMessage at 0x7dc9986f70d0>)"]},"metadata":{},"execution_count":6}]},{"cell_type":"code","source":["# Function to detect leading silence\n","# Returns the number of milliseconds until the first sound (chunk averaging more than X decibels)\n","def milliseconds_until_sound(sound, silence_threshold_in_decibels=-20.0, chunk_size=10):\n"," trim_ms = 0 # ms\n","\n"," assert chunk_size > 0 # to avoid infinite loop\n"," while sound[trim_ms:trim_ms+chunk_size].dBFS < silence_threshold_in_decibels and trim_ms < len(sound):\n"," trim_ms += chunk_size\n","\n"," return trim_ms"],"metadata":{"id":"qBS1pqzfdcSz","executionInfo":{"status":"ok","timestamp":1708851279522,"user_tz":-480,"elapsed":3,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["def trim_start(filepath):\n"," path = Path(filepath)\n"," directory = path.parent\n"," filename = path.name\n"," audio = AudioSegment.from_file(filepath, format=\"wav\")\n"," start_trim = milliseconds_until_sound(audio)\n"," trimmed = audio[start_trim:]\n"," new_filename = directory / f\"trimmed_{filename}\"\n"," trimmed.export(new_filename, format=\"wav\")\n"," return trimmed, new_filename"],"metadata":{"id":"0hIOI5YseTNg","executionInfo":{"status":"ok","timestamp":1708851290200,"user_tz":-480,"elapsed":342,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":8,"outputs":[]},{"cell_type":"code","source":["def transcribe_audio(file,output_dir):\n"," audio_path = os.path.join(output_dir, file)\n"," with open(audio_path, 'rb') as audio_data:\n"," transcription = client.audio.transcriptions.create(\n"," model=\"whisper-1\", file=audio_data)\n"," return transcription.text"],"metadata":{"id":"7Fd5gvnxeVvw","executionInfo":{"status":"ok","timestamp":1708851297581,"user_tz":-480,"elapsed":487,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":9,"outputs":[]},{"cell_type":"code","source":["# Define function to remove non-ascii characters\n","def remove_non_ascii(text):\n"," return ''.join(i for i in text if ord(i)<128)"],"metadata":{"id":"59HZSPGIeXe6","executionInfo":{"status":"ok","timestamp":1708851322999,"user_tz":-480,"elapsed":3,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":10,"outputs":[]},{"cell_type":"code","source":["# Define function to add punctuation\n","def punctuation_assistant(ascii_transcript):\n","\n"," system_prompt = \"\"\"You are a helpful assistant that adds punctuation to text.\n"," Preserve the original words and only insert necessary punctuation such as periods,\n"," commas, capialization, symbols like dollar sings or percentage signs, and formatting.\n"," Use only the context provided. 
If there is no context provided say, 'No context provided'\\n\"\"\"\n"," response = client.chat.completions.create(\n"," model=\"gpt-3.5-turbo\",\n"," temperature=0,\n"," messages=[\n"," {\n"," \"role\": \"system\",\n"," \"content\": system_prompt\n"," },\n"," {\n"," \"role\": \"user\",\n"," \"content\": ascii_transcript\n"," }\n"," ]\n"," )\n"," return response"],"metadata":{"id":"TmeaQdogedzk","executionInfo":{"status":"ok","timestamp":1708851345970,"user_tz":-480,"elapsed":503,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":11,"outputs":[]},{"cell_type":"code","source":["# Define function to fix product mispellings\n","def product_assistant(ascii_transcript):\n"," system_prompt = \"\"\"You are an intelligent assistant specializing in financial products;\n"," your task is to process transcripts of earnings calls, ensuring that all references to\n"," financial products and common financial terms are in the correct format. For each\n"," financial product or common term that is typically abbreviated as an acronym, the full term\n"," should be spelled out followed by the acronym in parentheses. For example, '401k' should be\n"," transformed to '401(k) retirement savings plan', 'HSA' should be transformed to 'Health Savings Account (HSA)'\n"," , 'ROA' should be transformed to 'Return on Assets (ROA)', 'VaR' should be transformed to 'Value at Risk (VaR)'\n",", and 'PB' should be transformed to 'Price to Book (PB) ratio'. Similarly, transform spoken numbers representing\n","financial products into their numeric representations, followed by the full name of the product in parentheses.\n","For instance, 'five two nine' to '529 (Education Savings Plan)' and 'four zero one k' to '401(k) (Retirement Savings Plan)'.\n"," However, be aware that some acronyms can have different meanings based on the context (e.g., 'LTV' can stand for\n","'Loan to Value' or 'Lifetime Value'). You will need to discern from the context which term is being referred to\n","and apply the appropriate transformation. In cases where numerical figures or metrics are spelled out but do not\n","represent specific financial products (like 'twenty three percent'), these should be left as is. Your role is to\n"," analyze and adjust financial product terminology in the text. 
Once you've done that, produce the adjusted\n"," transcript and a list of the words you've changed\"\"\"\n"," response = client.chat.completions.create(\n"," model=\"gpt-4\",\n"," temperature=0,\n"," messages=[\n"," {\n"," \"role\": \"system\",\n"," \"content\": system_prompt\n"," },\n"," {\n"," \"role\": \"user\",\n"," \"content\": ascii_transcript\n"," }\n"," ]\n"," )\n"," return response"],"metadata":{"id":"DiveGkRIejPY","executionInfo":{"status":"ok","timestamp":1708851409093,"user_tz":-480,"elapsed":358,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":12,"outputs":[]},{"cell_type":"code","source":["# Trim the start of the original audio file\n","trimmed_audio = trim_start(earnings_call_filepath)"],"metadata":{"id":"OGnRN5Hseyu4","executionInfo":{"status":"ok","timestamp":1708851422061,"user_tz":-480,"elapsed":2,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":13,"outputs":[]},{"cell_type":"code","source":["trimmed_audio, trimmed_filename = trim_start(earnings_call_filepath)"],"metadata":{"id":"vzt0dwHse2Ax","executionInfo":{"status":"ok","timestamp":1708851434593,"user_tz":-480,"elapsed":2,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":14,"outputs":[]},{"cell_type":"code","source":["# Segment audio\n","trimmed_audio = AudioSegment.from_wav(trimmed_filename) # Load the trimmed audio file\n","\n","one_minute = 1 * 60 * 1000 # Duration for each segment (in milliseconds)\n","\n","start_time = 0 # Start time for the first segment\n","\n","i = 0 # Index for naming the segmented files\n","\n","output_dir_trimmed = \"trimmed_earnings_directory\" # Output directory for the segmented files\n","\n","if not os.path.isdir(output_dir_trimmed): # Create the output directory if it does not exist\n"," os.makedirs(output_dir_trimmed)\n","\n","while start_time < len(trimmed_audio): # Loop over the trimmed audio file\n"," segment = trimmed_audio[start_time:start_time + one_minute] # Extract a segment\n"," segment.export(os.path.join(output_dir_trimmed, f\"trimmed_{i:02d}.wav\"), format=\"wav\") # Save the segment\n"," start_time += one_minute # Update the start time for the next segment\n"," i += 1 # Increment the index for naming the next file"],"metadata":{"id":"Lk7d7Gvme5Dk","executionInfo":{"status":"ok","timestamp":1708851447148,"user_tz":-480,"elapsed":3,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":15,"outputs":[]},{"cell_type":"code","source":["# Get list of trimmed and segmented audio files and sort them numerically\n","audio_files = sorted(\n"," (f for f in os.listdir(output_dir_trimmed) if f.endswith(\".wav\")),\n"," key=lambda f: int(''.join(filter(str.isdigit, f)))\n",")"],"metadata":{"id":"ouDjywBve8Fj","executionInfo":{"status":"ok","timestamp":1708851461397,"user_tz":-480,"elapsed":3,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":16,"outputs":[]},{"cell_type":"code","source":["# Use a loop to apply the transcribe function to all audio files\n","transcriptions = [transcribe_audio(file, output_dir_trimmed) for file in audio_files]"],"metadata":{"id":"TUpKUV2ke_gn","executionInfo":{"status":"ok","timestamp":1708851479212,"user_tz":-480,"elapsed":8685,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":17,"outputs":[]},{"cell_type":"code","source":["# Concatenate the transcriptions\n","full_transcript = ' 
'.join(transcriptions)"],"metadata":{"id":"f5IUwxVBfB2O","executionInfo":{"status":"ok","timestamp":1708851481684,"user_tz":-480,"elapsed":2,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":18,"outputs":[]},{"cell_type":"code","source":["print(full_transcript)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"A_A1CUa_fEk3","executionInfo":{"status":"ok","timestamp":1708851485687,"user_tz":-480,"elapsed":5,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"4e390258-f3d6-4c6b-9e3c-6de42ef5478f"},"execution_count":19,"outputs":[{"output_type":"stream","name":"stdout","text":["Good afternoon, everyone. And welcome to FinTech Plus Sync's second quarter 2023 earnings call. I'm John Doe, CEO of FinTech Plus. We've had a stellar Q2 with a revenue of 125 million, a 25% increase year over year. Our gross profit margin stands at a solid 58%, due in part to cost efficiencies gained from our scalable business model. Our EBITDA has surged to 37.5 million, translating to a remarkable 30% EBITDA margin. Our net income for the quarter rose to 16 million, which is a noteworthy increase from 10 million in Q2 2022. Our total addressable market has grown substantially thanks to the expansion of our high yield savings product line and the new RoboAdvisor platform. We've been diversifying our asset-backed securities portfolio, investing heavily in collateralized. debt obligations, and residential mortgage-backed securities. We've also invested $25 million in AAA rated corporate bonds, enhancing our risk adjusted returns. As for our balance sheet, total assets reached $1.5 billion with total liabilities at $900 million, leaving us with a solid equity base of $600 million. Our debt to equity ratio stands at 1.5, a healthy figure considering our expansionary phase. We continue to see substantial organic user growth, with customer acquisition cost dropping by 15% and lifetime value growing by 25%. Our LTVCAC ratio is at an impressive 3.5%. In terms of risk management, we have a value-at-risk model in place with a 99%... confidence level indicating that our maximum loss will not exceed 5 million in the next trading day. We've adopted a conservative approach to managing our leverage and have a healthy tier one capital ratio of 12.5%. Our forecast for the coming quarter is positive. We expect revenue to be around 135 million and 8% quarter over quarter growth driven primarily by our cutting edge blockchain solutions and AI driven predictive analytics. We're also excited about the upcoming IPO of our FinTech subsidiary, Pay Plus, which we expect to raise 200 million. Significantly bolstering our liquidity and paving the way for aggressive growth strategies. We thank our shareholders for their continued faith in us and we look forward to an even more successful Q3. 
Thank you so much.\n"]}]},{"cell_type":"code","source":["# Remove non-ascii characters from the transcript\n","ascii_transcript = remove_non_ascii(full_transcript)"],"metadata":{"id":"kjFntLyPfFh7","executionInfo":{"status":"ok","timestamp":1708851493166,"user_tz":-480,"elapsed":3,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":20,"outputs":[]},{"cell_type":"code","source":["\n","print(ascii_transcript)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"5m5mZTL0fHTz","executionInfo":{"status":"ok","timestamp":1708851498271,"user_tz":-480,"elapsed":2,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"d72c78e7-e6e8-4094-9f88-7df65b7665fc"},"execution_count":21,"outputs":[{"output_type":"stream","name":"stdout","text":["Good afternoon, everyone. And welcome to FinTech Plus Sync's second quarter 2023 earnings call. I'm John Doe, CEO of FinTech Plus. We've had a stellar Q2 with a revenue of 125 million, a 25% increase year over year. Our gross profit margin stands at a solid 58%, due in part to cost efficiencies gained from our scalable business model. Our EBITDA has surged to 37.5 million, translating to a remarkable 30% EBITDA margin. Our net income for the quarter rose to 16 million, which is a noteworthy increase from 10 million in Q2 2022. Our total addressable market has grown substantially thanks to the expansion of our high yield savings product line and the new RoboAdvisor platform. We've been diversifying our asset-backed securities portfolio, investing heavily in collateralized. debt obligations, and residential mortgage-backed securities. We've also invested $25 million in AAA rated corporate bonds, enhancing our risk adjusted returns. As for our balance sheet, total assets reached $1.5 billion with total liabilities at $900 million, leaving us with a solid equity base of $600 million. Our debt to equity ratio stands at 1.5, a healthy figure considering our expansionary phase. We continue to see substantial organic user growth, with customer acquisition cost dropping by 15% and lifetime value growing by 25%. Our LTVCAC ratio is at an impressive 3.5%. In terms of risk management, we have a value-at-risk model in place with a 99%... confidence level indicating that our maximum loss will not exceed 5 million in the next trading day. We've adopted a conservative approach to managing our leverage and have a healthy tier one capital ratio of 12.5%. Our forecast for the coming quarter is positive. We expect revenue to be around 135 million and 8% quarter over quarter growth driven primarily by our cutting edge blockchain solutions and AI driven predictive analytics. We're also excited about the upcoming IPO of our FinTech subsidiary, Pay Plus, which we expect to raise 200 million. Significantly bolstering our liquidity and paving the way for aggressive growth strategies. We thank our shareholders for their continued faith in us and we look forward to an even more successful Q3. 
Thank you so much.\n"]}]},{"cell_type":"code","source":["# Use punctuation assistant function\n","response = punctuation_assistant(ascii_transcript)"],"metadata":{"id":"c2XFW-VGfIn-","executionInfo":{"status":"ok","timestamp":1708851512225,"user_tz":-480,"elapsed":7046,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":22,"outputs":[]},{"cell_type":"code","source":["# Extract the punctuated transcript from the model's response\n","punctuated_transcript = response.choices[0].message.content"],"metadata":{"id":"prOqzj3ffKTu","executionInfo":{"status":"ok","timestamp":1708851512902,"user_tz":-480,"elapsed":2,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":23,"outputs":[]},{"cell_type":"code","source":["print(punctuated_transcript)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"DLbooVITfMK0","executionInfo":{"status":"ok","timestamp":1708851522144,"user_tz":-480,"elapsed":3,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"087b5ac5-56ba-46c3-d063-2b873d155361"},"execution_count":24,"outputs":[{"output_type":"stream","name":"stdout","text":["Good afternoon, everyone, and welcome to FinTech Plus Sync's second quarter 2023 earnings call. I'm John Doe, CEO of FinTech Plus. We've had a stellar Q2 with a revenue of 125 million, a 25% increase year over year. Our gross profit margin stands at a solid 58%, due in part to cost efficiencies gained from our scalable business model. Our EBITDA has surged to 37.5 million, translating to a remarkable 30% EBITDA margin. Our net income for the quarter rose to 16 million, which is a noteworthy increase from 10 million in Q2 2022. Our total addressable market has grown substantially thanks to the expansion of our high yield savings product line and the new RoboAdvisor platform. We've been diversifying our asset-backed securities portfolio, investing heavily in collateralized debt obligations, and residential mortgage-backed securities. We've also invested $25 million in AAA rated corporate bonds, enhancing our risk-adjusted returns. As for our balance sheet, total assets reached $1.5 billion with total liabilities at $900 million, leaving us with a solid equity base of $600 million. Our debt-to-equity ratio stands at 1.5, a healthy figure considering our expansionary phase. We continue to see substantial organic user growth, with customer acquisition cost dropping by 15% and lifetime value growing by 25%. Our LTVCAC ratio is at an impressive 3.5%. In terms of risk management, we have a value-at-risk model in place with a 99% confidence level indicating that our maximum loss will not exceed 5 million in the next trading day. We've adopted a conservative approach to managing our leverage and have a healthy tier one capital ratio of 12.5%. Our forecast for the coming quarter is positive. We expect revenue to be around 135 million and 8% quarter over quarter growth driven primarily by our cutting-edge blockchain solutions and AI-driven predictive analytics. We're also excited about the upcoming IPO of our FinTech subsidiary, Pay Plus, which we expect to raise 200 million, significantly bolstering our liquidity and paving the way for aggressive growth strategies. We thank our shareholders for their continued faith in us, and we look forward to an even more successful Q3. 
Thank you so much.\n"]}]},{"cell_type":"code","source":["# Use product assistant function\n","response = product_assistant(punctuated_transcript)"],"metadata":{"id":"gDXx2uFgfOcE","executionInfo":{"status":"ok","timestamp":1708851550633,"user_tz":-480,"elapsed":23541,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":25,"outputs":[]},{"cell_type":"code","source":["# Extract the final transcript from the model's response\n","final_transcript = response.choices[0].message.content"],"metadata":{"id":"gC2f0PJGfPk7","executionInfo":{"status":"ok","timestamp":1708851550633,"user_tz":-480,"elapsed":19,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":26,"outputs":[]},{"cell_type":"code","source":["print(final_transcript)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"CPnsDbxzfQ8j","executionInfo":{"status":"ok","timestamp":1708851550633,"user_tz":-480,"elapsed":18,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"a321937d-6557-4c33-d699-e5a7ecea76dd"},"execution_count":27,"outputs":[{"output_type":"stream","name":"stdout","text":["Good afternoon, everyone, and welcome to FinTech Plus Sync's second quarter 2023 earnings call. I'm John Doe, CEO of FinTech Plus. We've had a stellar second quarter (Q2) with a revenue of 125 million, a 25% increase year over year. Our gross profit margin stands at a solid 58%, due in part to cost efficiencies gained from our scalable business model. Our Earnings Before Interest, Taxes, Depreciation, and Amortization (EBITDA) has surged to 37.5 million, translating to a remarkable 30% EBITDA margin. Our net income for the quarter rose to 16 million, which is a noteworthy increase from 10 million in Q2 2022. Our total addressable market has grown substantially thanks to the expansion of our high yield savings product line and the new RoboAdvisor platform. We've been diversifying our asset-backed securities portfolio, investing heavily in Collateralized Debt Obligations (CDOs), and Residential Mortgage-Backed Securities (RMBS). We've also invested $25 million in AAA rated corporate bonds, enhancing our risk-adjusted returns. As for our balance sheet, total assets reached $1.5 billion with total liabilities at $900 million, leaving us with a solid equity base of $600 million. Our Debt-to-Equity (D/E) ratio stands at 1.5, a healthy figure considering our expansionary phase. We continue to see substantial organic user growth, with Customer Acquisition Cost (CAC) dropping by 15% and Lifetime Value (LTV) growing by 25%. Our LTV to CAC (LTVCAC) ratio is at an impressive 3.5%. In terms of risk management, we have a Value at Risk (VaR) model in place with a 99% confidence level indicating that our maximum loss will not exceed 5 million in the next trading day. We've adopted a conservative approach to managing our leverage and have a healthy Tier 1 Capital ratio of 12.5%. Our forecast for the coming quarter is positive. We expect revenue to be around 135 million and 8% quarter over quarter growth driven primarily by our cutting-edge blockchain solutions and AI-driven predictive analytics. We're also excited about the upcoming Initial Public Offering (IPO) of our FinTech subsidiary, Pay Plus, which we expect to raise 200 million, significantly bolstering our liquidity and paving the way for aggressive growth strategies. We thank our shareholders for their continued faith in us, and we look forward to an even more successful Q3. 
Thank you so much.\n","\n","Words Changed:\n","1. Q2 to second quarter (Q2)\n","2. EBITDA to Earnings Before Interest, Taxes, Depreciation, and Amortization (EBITDA)\n","3. CDOs to Collateralized Debt Obligations (CDOs)\n","4. RMBS to Residential Mortgage-Backed Securities (RMBS)\n","5. D/E to Debt-to-Equity (D/E)\n","6. CAC to Customer Acquisition Cost (CAC)\n","7. LTV to Lifetime Value (LTV)\n","8. LTVCAC to LTV to CAC (LTVCAC)\n","9. VaR to Value at Risk (VaR)\n","10. IPO to Initial Public Offering (IPO)\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"1n9UroAxfSNy"},"execution_count":null,"outputs":[]}]}
\ No newline at end of file
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOUYiHanYAsDKHOE8u614e+"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","source":["!pip install openai"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"coTGz0rOTQgc","executionInfo":{"status":"ok","timestamp":1708848399555,"user_tz":-480,"elapsed":8409,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"68f777e2-463f-4184-9607-acd39df74922"},"execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting openai\n"," Downloading openai-1.12.0-py3-none-any.whl (226 kB)\n","\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/226.7 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━━━━━\u001b[0m \u001b[32m143.4/226.7 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m226.7/226.7 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from openai) (3.7.1)\n","Requirement already satisfied: distro<2,>=1.7.0 in /usr/lib/python3/dist-packages (from openai) (1.7.0)\n","Collecting httpx<1,>=0.23.0 (from openai)\n"," Downloading httpx-0.27.0-py3-none-any.whl (75 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.6/75.6 kB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from openai) (2.6.1)\n","Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from openai) (1.3.0)\n","Requirement already satisfied: tqdm>4 in /usr/local/lib/python3.10/dist-packages (from openai) (4.66.2)\n","Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from openai) (4.9.0)\n","Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (3.6)\n","Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (1.2.0)\n","Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->openai) (2024.2.2)\n","Collecting httpcore==1.* (from httpx<1,>=0.23.0->openai)\n"," Downloading httpcore-1.0.4-py3-none-any.whl (77 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.8/77.8 kB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting h11<0.15,>=0.13 (from httpcore==1.*->httpx<1,>=0.23.0->openai)\n"," Downloading h11-0.14.0-py3-none-any.whl (58 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (0.6.0)\n","Requirement already satisfied: pydantic-core==2.16.2 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (2.16.2)\n","Installing collected packages: h11, httpcore, httpx, 
openai\n","Successfully installed h11-0.14.0 httpcore-1.0.4 httpx-0.27.0 openai-1.12.0\n"]}]},{"cell_type":"code","execution_count":4,"metadata":{"id":"zXxzVeSkS5L0","executionInfo":{"status":"ok","timestamp":1708848410118,"user_tz":-480,"elapsed":5242,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"outputs":[],"source":["# imports\n","from openai import OpenAI # for making OpenAI API calls\n","import urllib # for downloading example audio files\n","import os # for accessing environment variables\n","\n","from google.colab import userdata\n","key = userdata.get('OpenAI-Key')\n","\n","client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", key))"]},{"cell_type":"code","source":["# set download paths\n","ZyntriQix_remote_filepath = \"https://cdn.openai.com/API/examples/data/ZyntriQix.wav\"\n","\n","\n","# set local save locations\n","ZyntriQix_filepath = \"ZyntriQix.wav\"\n","\n","# download example audio files and save locally\n","urllib.request.urlretrieve(ZyntriQix_remote_filepath, ZyntriQix_filepath)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"BPlMaAdhTOIX","executionInfo":{"status":"ok","timestamp":1708848430837,"user_tz":-480,"elapsed":410,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"cc0217f6-7606-46fe-b2e8-7796078149b4"},"execution_count":6,"outputs":[{"output_type":"execute_result","data":{"text/plain":["('ZyntriQix.wav', <http.client.HTTPMessage at 0x7b5121cc2e90>)"]},"metadata":{},"execution_count":6}]},{"cell_type":"code","source":["# define a wrapper function for seeing how prompts affect transcriptions\n","def transcribe(prompt: str, audio_filepath) -> str:\n"," \"\"\"Given a prompt, transcribe the audio file.\"\"\"\n"," transcript = client.audio.transcriptions.create(\n"," file=open(audio_filepath, \"rb\"),\n"," model=\"whisper-1\",\n"," prompt=prompt,\n"," )\n"," return transcript.text"],"metadata":{"id":"bBbfZ4hQTdHu","executionInfo":{"status":"ok","timestamp":1708848501454,"user_tz":-480,"elapsed":2,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["# baseline transcription with no prompt\n","transcribe(prompt=\"\", audio_filepath=ZyntriQix_filepath)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":106},"id":"N1Ai56QTTsez","executionInfo":{"status":"ok","timestamp":1708848508182,"user_tz":-480,"elapsed":2968,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"f215b6a3-3aa6-43d8-ddf5-d5852af507ea"},"execution_count":8,"outputs":[{"output_type":"execute_result","data":{"text/plain":["\"Have you heard of ZentricX? This tech giant boasts products like Digi-Q+, Synapse 5, VortiCore V8, Echo Nix Array, and not to forget the latest Orbital Link 7 and Digifractal Matrix. Their innovation arsenal also includes the Pulse framework, Wrapped system, they've developed a brick infrastructure court system, and launched the Flint initiative, all highlighting their commitment to relentless innovation. ZentricX, in just 30 years, has soared from a startup to a tech titan, serving us tech marvels alongside a stimulating linguistic challenge. 
Quite an adventure, wouldn't you agree?\""],"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"}},"metadata":{},"execution_count":8}]},{"cell_type":"code","source":["# add the correct spelling names to the prompt\n","transcribe(\n"," prompt=\"ZyntriQix, Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, OrbitalLink Seven, DigiFractal Matrix, PULSE, RAPT, B.R.I.C.K., Q.U.A.R.T.Z., F.L.I.N.T.\",\n"," audio_filepath=ZyntriQix_filepath,\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":107},"id":"An3B2Y9zTt2O","executionInfo":{"status":"ok","timestamp":1708848518387,"user_tz":-480,"elapsed":2814,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"d322550d-9f1a-414e-d500-aff25d458555"},"execution_count":9,"outputs":[{"output_type":"execute_result","data":{"text/plain":["\"Have you heard of ZyntriQix? This tech giant boasts products like Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, and not to forget the latest OrbitalLink Seven and DigiFractal Matrix. Their innovation arsenal also includes the PULSE framework, RAPT system. They've developed a B.R.I.C.K. infrastructure, Q.U.A.R.T. system, and launched the F.L.I.N.T. initiative, all highlighting their commitment to relentless innovation. ZyntriQix in just 30 years has soared from a startup to a tech titan, serving us tech marvels alongside a stimulating linguistic challenge. Quite an adventure, wouldn't you agree?\""],"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"}},"metadata":{},"execution_count":9}]},{"cell_type":"code","source":["# add a full product list to the prompt\n","transcribe(\n"," prompt=\"ZyntriQix, Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, OrbitalLink Seven, DigiFractal Matrix, PULSE, RAPT, AstroPixel Array, QuantumFlare Five, CyberPulse Six, VortexDrive Matrix, PhotonLink Ten, TriCircuit Array, PentaSync Seven, UltraWave Eight, QuantumVertex Nine, HyperHelix X, DigiSpiral Z, PentaQuark Eleven, TetraCube Twelve, GigaPhase Thirteen, EchoNeuron Fourteen, FusionPulse V15, MetaQuark Sixteen, InfiniCircuit Seventeen, TeraPulse Eighteen, ExoMatrix Nineteen, OrbiSync Twenty, QuantumHelix TwentyOne, NanoPhase TwentyTwo, TeraFractal TwentyThree, PentaHelix TwentyFour, ExoCircuit TwentyFive, HyperQuark TwentySix, B.R.I.C.K., Q.U.A.R.T.Z., F.L.I.N.T.\",\n"," audio_filepath=ZyntriQix_filepath,\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":107},"id":"BigHfdf2TwWz","executionInfo":{"status":"ok","timestamp":1708848563186,"user_tz":-480,"elapsed":2891,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"49d73053-e880-4c81-8ec0-501dcb96886d"},"execution_count":10,"outputs":[{"output_type":"execute_result","data":{"text/plain":["\"Have you heard of ZentricX? This tech giant boasts products like DigiCube Plus, Synapse 5, VortiCore V8, EchoNix Array, and not to forget the latest Orbital Link 7 and Digifractal Matrix. Their innovation arsenal also includes the PULSE framework, RAPT system. They've developed a brick infrastructure court system and launched the F.L.I.N.T. initiative, all highlighting their commitment to relentless innovation. ZentricX in just 30 years has soared from a startup to a tech titan, serving us tech marvels alongside a stimulating linguistic challenge. 
Quite an adventure, wouldn't you agree?\""],"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"}},"metadata":{},"execution_count":10}]},{"cell_type":"code","source":["# define a wrapper function for seeing how prompts affect transcriptions\n","def transcribe_with_spellcheck(system_message, audio_filepath):\n"," completion = client.chat.completions.create(\n"," model=\"gpt-4\",\n"," temperature=0,\n"," messages=[\n"," {\"role\": \"system\", \"content\": system_message},\n"," {\n"," \"role\": \"user\",\n"," \"content\": transcribe(prompt=\"\", audio_filepath=audio_filepath),\n"," },\n"," ],\n"," )\n"," return completion.choices[0].message.content"],"metadata":{"id":"qgAtUs7zT7Rp","executionInfo":{"status":"ok","timestamp":1708848881820,"user_tz":-480,"elapsed":436,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":12,"outputs":[]},{"cell_type":"code","source":["system_prompt = \"You are a helpful assistant for the company ZyntriQix. Your task is to correct any spelling discrepancies in the transcribed text. Make sure that the names of the following products are spelled correctly: ZyntriQix, Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, OrbitalLink Seven, DigiFractal Matrix, PULSE, RAPT, B.R.I.C.K., Q.U.A.R.T.Z., F.L.I.N.T.\"\n","new_text = transcribe_with_spellcheck(system_prompt, audio_filepath=ZyntriQix_filepath)\n","print(new_text)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Y8yDB69gVJu-","executionInfo":{"status":"ok","timestamp":1708848973250,"user_tz":-480,"elapsed":7268,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"cb4154b0-4663-4104-d581-918a28de1abb"},"execution_count":14,"outputs":[{"output_type":"stream","name":"stdout","text":["Have you heard of ZyntriQix? This tech giant boasts products like Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, and not to forget the latest OrbitalLink Seven and DigiFractal Matrix. Their innovation arsenal also includes the PULSE framework, RAPT system, they've developed a B.R.I.C.K. infrastructure court system, and launched the F.L.I.N.T. initiative, all highlighting their commitment to relentless innovation. ZyntriQix, in just 30 years, has soared from a startup to a tech titan, serving us tech marvels alongside a stimulating linguistic challenge. Quite an adventure, wouldn't you agree?\n"]}]},{"cell_type":"code","source":["system_prompt = \"You are a helpful assistant for the company ZyntriQix. Your task is to correct any spelling discrepancies in the transcribed text. Make sure that the names of the following products are spelled correctly: ZyntriQix, Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, OrbitalLink Seven, DigiFractal Matrix, PULSE, RAPT, AstroPixel Array, QuantumFlare Five, CyberPulse Six, VortexDrive Matrix, PhotonLink Ten, TriCircuit Array, PentaSync Seven, UltraWave Eight, QuantumVertex Nine, HyperHelix X, DigiSpiral Z, PentaQuark Eleven, TetraCube Twelve, GigaPhase Thirteen, EchoNeuron Fourteen, FusionPulse V15, MetaQuark Sixteen, InfiniCircuit Seventeen, TeraPulse Eighteen, ExoMatrix Nineteen, OrbiSync Twenty, QuantumHelix TwentyOne, NanoPhase TwentyTwo, TeraFractal TwentyThree, PentaHelix TwentyFour, ExoCircuit TwentyFive, HyperQuark TwentySix, GigaLink TwentySeven, FusionMatrix TwentyEight, InfiniFractal TwentyNine, MetaSync Thirty, B.R.I.C.K., Q.U.A.R.T.Z., F.L.I.N.T. 
Only add necessary punctuation such as periods, commas, and capitalization, and use only the context provided.\"\n","new_text = transcribe_with_spellcheck(system_prompt, audio_filepath=ZyntriQix_filepath)\n","print(new_text)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"RPQSOTYZVeUw","executionInfo":{"status":"ok","timestamp":1708849042978,"user_tz":-480,"elapsed":7098,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"a2ef0200-9128-4bd2-d32d-55091c688caa"},"execution_count":15,"outputs":[{"output_type":"stream","name":"stdout","text":["Have you heard of ZyntriQix? This tech giant boasts products like Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, and not to forget the latest OrbitalLink Seven and DigiFractal Matrix. Their innovation arsenal also includes the PULSE framework, RAPT system, they've developed a B.R.I.C.K. infrastructure court system, and launched the F.L.I.N.T. initiative, all highlighting their commitment to relentless innovation. ZyntriQix, in just 30 years, has soared from a startup to a tech titan, serving us tech marvels alongside a stimulating linguistic challenge. Quite an adventure, wouldn't you agree?\n"]}]},{"cell_type":"code","source":["system_prompt = \"You are a helpful assistant for the company ZyntriQix. Your first task is to list the words that are not spelled correctly according to the list provided to you and to tell me the number of misspelled words. Your next task is to insert those correct words in place of the misspelled ones. List: ZyntriQix, Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, OrbitalLink Seven, DigiFractal Matrix, PULSE, RAPT, AstroPixel Array, QuantumFlare Five, CyberPulse Six, VortexDrive Matrix, PhotonLink Ten, TriCircuit Array, PentaSync Seven, UltraWave Eight, QuantumVertex Nine, HyperHelix X, DigiSpiral Z, PentaQuark Eleven, TetraCube Twelve, GigaPhase Thirteen, EchoNeuron Fourteen, FusionPulse V15, MetaQuark Sixteen, InfiniCircuit Seventeen, TeraPulse Eighteen, ExoMatrix Nineteen, OrbiSync Twenty, QuantumHelix TwentyOne, NanoPhase TwentyTwo, TeraFractal TwentyThree, PentaHelix TwentyFour, ExoCircuit TwentyFive, HyperQuark TwentySix, GigaLink TwentySeven, FusionMatrix TwentyEight, InfiniFractal TwentyNine, MetaSync Thirty, B.R.I.C.K., Q.U.A.R.T.Z., F.L.I.N.T.\"\n","new_text = transcribe_with_spellcheck(system_prompt, audio_filepath=ZyntriQix_filepath)\n","print(new_text)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"XFaPTXDnVvdb","executionInfo":{"status":"ok","timestamp":1708849070064,"user_tz":-480,"elapsed":11266,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"019bf4fa-eac2-4e62-e7ab-88d2a01e150c"},"execution_count":16,"outputs":[{"output_type":"stream","name":"stdout","text":["The misspelled words are: ZentricX, Digi-Q+, Synapse 5, VortiCore V8, Echo Nix Array, Orbital Link 7, Digifractal Matrix, Pulse, Wrapped, brick, Flint. The number of misspelled words is 11.\n","\n","The corrected paragraph is:\n","\n","Have you heard of ZyntriQix? This tech giant boasts products like Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, and not to forget the latest OrbitalLink Seven and DigiFractal Matrix. Their innovation arsenal also includes the PULSE framework, RAPT system, they've developed a B.R.I.C.K. infrastructure court system, and launched the F.L.I.N.T. initiative, all highlighting their commitment to relentless innovation. 
ZyntriQix, in just 30 years, has soared from a startup to a tech titan, serving us tech marvels alongside a stimulating linguistic challenge. Quite an adventure, wouldn't you agree?\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"F9m8pO76V0_3"},"execution_count":null,"outputs":[]}]}
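The cells above settle on a two-step pipeline: get a raw transcript from whisper-1, then hand it to GPT-4 with the correct product spellings in the system prompt. As a reading aid for the minified JSON, here is a minimal, self-contained sketch of that pipeline; it assumes the openai 1.x SDK, an OPENAI_API_KEY environment variable, and a local ZyntriQix.wav file, and trims the product list for brevity.

# Sketch of the two-step pipeline shown above: Whisper transcription, then a
# chat-model spellcheck pass. Assumes OPENAI_API_KEY is set and ZyntriQix.wav
# exists locally; the product list here is a shortened placeholder.
import os
from openai import OpenAI

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

PRODUCT_NAMES = "ZyntriQix, Digique Plus, CynapseFive, VortiQore V8"  # trimmed example list


def transcribe(audio_filepath: str, prompt: str = "") -> str:
    """Transcribe an audio file with whisper-1, optionally biasing spellings via the prompt."""
    with open(audio_filepath, "rb") as audio_file:
        transcript = client.audio.transcriptions.create(
            file=audio_file, model="whisper-1", prompt=prompt
        )
    return transcript.text


def spellcheck(transcript_text: str) -> str:
    """Post-process a raw transcript with a chat model that knows the correct spellings."""
    system_message = (
        "You are a helpful assistant. Correct any spelling discrepancies in the "
        f"transcribed text. These product names are spelled correctly: {PRODUCT_NAMES}. "
        "Only add necessary punctuation and use only the context provided."
    )
    completion = client.chat.completions.create(
        model="gpt-4",
        temperature=0,
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": transcript_text},
        ],
    )
    return completion.choices[0].message.content


if __name__ == "__main__":
    raw = transcribe("ZyntriQix.wav")  # baseline transcription, no prompt
    print(spellcheck(raw))             # corrected transcript

As the outputs above show, the chat-model pass keeps working as the product list grows, whereas stuffing the full list into Whisper's own prompt starts to drop or mangle names.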
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOUYiHanYAsDKHOE8u614e+"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","source":["!pip install openai"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"coTGz0rOTQgc","executionInfo":{"status":"ok","timestamp":1708848399555,"user_tz":-480,"elapsed":8409,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"68f777e2-463f-4184-9607-acd39df74922"},"execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting openai\n"," Downloading openai-1.12.0-py3-none-any.whl (226 kB)\n","\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/226.7 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━━━━━\u001b[0m \u001b[32m143.4/226.7 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m226.7/226.7 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from openai) (3.7.1)\n","Requirement already satisfied: distro<2,>=1.7.0 in /usr/lib/python3/dist-packages (from openai) (1.7.0)\n","Collecting httpx<1,>=0.23.0 (from openai)\n"," Downloading httpx-0.27.0-py3-none-any.whl (75 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.6/75.6 kB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from openai) (2.6.1)\n","Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from openai) (1.3.0)\n","Requirement already satisfied: tqdm>4 in /usr/local/lib/python3.10/dist-packages (from openai) (4.66.2)\n","Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from openai) (4.9.0)\n","Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (3.6)\n","Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (1.2.0)\n","Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->openai) (2024.2.2)\n","Collecting httpcore==1.* (from httpx<1,>=0.23.0->openai)\n"," Downloading httpcore-1.0.4-py3-none-any.whl (77 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.8/77.8 kB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting h11<0.15,>=0.13 (from httpcore==1.*->httpx<1,>=0.23.0->openai)\n"," Downloading h11-0.14.0-py3-none-any.whl (58 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (0.6.0)\n","Requirement already satisfied: pydantic-core==2.16.2 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (2.16.2)\n","Installing collected packages: h11, httpcore, httpx, 
openai\n","Successfully installed h11-0.14.0 httpcore-1.0.4 httpx-0.27.0 openai-1.12.0\n"]}]},{"cell_type":"code","execution_count":4,"metadata":{"id":"zXxzVeSkS5L0","executionInfo":{"status":"ok","timestamp":1708848410118,"user_tz":-480,"elapsed":5242,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"outputs":[],"source":["# imports\n","from openai import OpenAI # for making OpenAI API calls\n","import urllib # for downloading example audio files\n","import os # for accessing environment variables\n","\n","from google.colab import userdata\n","key = userdata.get('OpenAI-Key')\n","\n","client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", key))"]},{"cell_type":"code","source":["# set download paths\n","ZyntriQix_remote_filepath = \"https://cdn.openai.com/API/examples/data/ZyntriQix.wav\"\n","\n","\n","# set local save locations\n","ZyntriQix_filepath = \"ZyntriQix.wav\"\n","\n","# download example audio files and save locally\n","urllib.request.urlretrieve(ZyntriQix_remote_filepath, ZyntriQix_filepath)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"BPlMaAdhTOIX","executionInfo":{"status":"ok","timestamp":1708848430837,"user_tz":-480,"elapsed":410,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"cc0217f6-7606-46fe-b2e8-7796078149b4"},"execution_count":6,"outputs":[{"output_type":"execute_result","data":{"text/plain":["('ZyntriQix.wav', <http.client.HTTPMessage at 0x7b5121cc2e90>)"]},"metadata":{},"execution_count":6}]},{"cell_type":"code","source":["# define a wrapper function for seeing how prompts affect transcriptions\n","def transcribe(prompt: str, audio_filepath) -> str:\n"," \"\"\"Given a prompt, transcribe the audio file.\"\"\"\n"," transcript = client.audio.transcriptions.create(\n"," file=open(audio_filepath, \"rb\"),\n"," model=\"whisper-1\",\n"," prompt=prompt,\n"," )\n"," return transcript.text"],"metadata":{"id":"bBbfZ4hQTdHu","executionInfo":{"status":"ok","timestamp":1708848501454,"user_tz":-480,"elapsed":2,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["# baseline transcription with no prompt\n","transcribe(prompt=\"\", audio_filepath=ZyntriQix_filepath)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":106},"id":"N1Ai56QTTsez","executionInfo":{"status":"ok","timestamp":1708848508182,"user_tz":-480,"elapsed":2968,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"f215b6a3-3aa6-43d8-ddf5-d5852af507ea"},"execution_count":8,"outputs":[{"output_type":"execute_result","data":{"text/plain":["\"Have you heard of ZentricX? This tech giant boasts products like Digi-Q+, Synapse 5, VortiCore V8, Echo Nix Array, and not to forget the latest Orbital Link 7 and Digifractal Matrix. Their innovation arsenal also includes the Pulse framework, Wrapped system, they've developed a brick infrastructure court system, and launched the Flint initiative, all highlighting their commitment to relentless innovation. ZentricX, in just 30 years, has soared from a startup to a tech titan, serving us tech marvels alongside a stimulating linguistic challenge. 
Quite an adventure, wouldn't you agree?\""],"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"}},"metadata":{},"execution_count":8}]},{"cell_type":"code","source":["# add the correct spelling names to the prompt\n","transcribe(\n"," prompt=\"ZyntriQix, Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, OrbitalLink Seven, DigiFractal Matrix, PULSE, RAPT, B.R.I.C.K., Q.U.A.R.T.Z., F.L.I.N.T.\",\n"," audio_filepath=ZyntriQix_filepath,\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":107},"id":"An3B2Y9zTt2O","executionInfo":{"status":"ok","timestamp":1708848518387,"user_tz":-480,"elapsed":2814,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"d322550d-9f1a-414e-d500-aff25d458555"},"execution_count":9,"outputs":[{"output_type":"execute_result","data":{"text/plain":["\"Have you heard of ZyntriQix? This tech giant boasts products like Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, and not to forget the latest OrbitalLink Seven and DigiFractal Matrix. Their innovation arsenal also includes the PULSE framework, RAPT system. They've developed a B.R.I.C.K. infrastructure, Q.U.A.R.T. system, and launched the F.L.I.N.T. initiative, all highlighting their commitment to relentless innovation. ZyntriQix in just 30 years has soared from a startup to a tech titan, serving us tech marvels alongside a stimulating linguistic challenge. Quite an adventure, wouldn't you agree?\""],"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"}},"metadata":{},"execution_count":9}]},{"cell_type":"code","source":["# add a full product list to the prompt\n","transcribe(\n"," prompt=\"ZyntriQix, Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, OrbitalLink Seven, DigiFractal Matrix, PULSE, RAPT, AstroPixel Array, QuantumFlare Five, CyberPulse Six, VortexDrive Matrix, PhotonLink Ten, TriCircuit Array, PentaSync Seven, UltraWave Eight, QuantumVertex Nine, HyperHelix X, DigiSpiral Z, PentaQuark Eleven, TetraCube Twelve, GigaPhase Thirteen, EchoNeuron Fourteen, FusionPulse V15, MetaQuark Sixteen, InfiniCircuit Seventeen, TeraPulse Eighteen, ExoMatrix Nineteen, OrbiSync Twenty, QuantumHelix TwentyOne, NanoPhase TwentyTwo, TeraFractal TwentyThree, PentaHelix TwentyFour, ExoCircuit TwentyFive, HyperQuark TwentySix, B.R.I.C.K., Q.U.A.R.T.Z., F.L.I.N.T.\",\n"," audio_filepath=ZyntriQix_filepath,\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":107},"id":"BigHfdf2TwWz","executionInfo":{"status":"ok","timestamp":1708848563186,"user_tz":-480,"elapsed":2891,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"49d73053-e880-4c81-8ec0-501dcb96886d"},"execution_count":10,"outputs":[{"output_type":"execute_result","data":{"text/plain":["\"Have you heard of ZentricX? This tech giant boasts products like DigiCube Plus, Synapse 5, VortiCore V8, EchoNix Array, and not to forget the latest Orbital Link 7 and Digifractal Matrix. Their innovation arsenal also includes the PULSE framework, RAPT system. They've developed a brick infrastructure court system and launched the F.L.I.N.T. initiative, all highlighting their commitment to relentless innovation. ZentricX in just 30 years has soared from a startup to a tech titan, serving us tech marvels alongside a stimulating linguistic challenge. 
Quite an adventure, wouldn't you agree?\""],"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"}},"metadata":{},"execution_count":10}]},{"cell_type":"code","source":["# define a wrapper function for seeing how prompts affect transcriptions\n","def transcribe_with_spellcheck(system_message, audio_filepath):\n"," completion = client.chat.completions.create(\n"," model=\"gpt-4\",\n"," temperature=0,\n"," messages=[\n"," {\"role\": \"system\", \"content\": system_message},\n"," {\n"," \"role\": \"user\",\n"," \"content\": transcribe(prompt=\"\", audio_filepath=audio_filepath),\n"," },\n"," ],\n"," )\n"," return completion.choices[0].message.content"],"metadata":{"id":"qgAtUs7zT7Rp","executionInfo":{"status":"ok","timestamp":1708848881820,"user_tz":-480,"elapsed":436,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":12,"outputs":[]},{"cell_type":"code","source":["system_prompt = \"You are a helpful assistant for the company ZyntriQix. Your task is to correct any spelling discrepancies in the transcribed text. Make sure that the names of the following products are spelled correctly: ZyntriQix, Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, OrbitalLink Seven, DigiFractal Matrix, PULSE, RAPT, B.R.I.C.K., Q.U.A.R.T.Z., F.L.I.N.T.\"\n","new_text = transcribe_with_spellcheck(system_prompt, audio_filepath=ZyntriQix_filepath)\n","print(new_text)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Y8yDB69gVJu-","executionInfo":{"status":"ok","timestamp":1708848973250,"user_tz":-480,"elapsed":7268,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"cb4154b0-4663-4104-d581-918a28de1abb"},"execution_count":14,"outputs":[{"output_type":"stream","name":"stdout","text":["Have you heard of ZyntriQix? This tech giant boasts products like Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, and not to forget the latest OrbitalLink Seven and DigiFractal Matrix. Their innovation arsenal also includes the PULSE framework, RAPT system, they've developed a B.R.I.C.K. infrastructure court system, and launched the F.L.I.N.T. initiative, all highlighting their commitment to relentless innovation. ZyntriQix, in just 30 years, has soared from a startup to a tech titan, serving us tech marvels alongside a stimulating linguistic challenge. Quite an adventure, wouldn't you agree?\n"]}]},{"cell_type":"code","source":["system_prompt = \"You are a helpful assistant for the company ZyntriQix. Your task is to correct any spelling discrepancies in the transcribed text. Make sure that the names of the following products are spelled correctly: ZyntriQix, Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, OrbitalLink Seven, DigiFractal Matrix, PULSE, RAPT, AstroPixel Array, QuantumFlare Five, CyberPulse Six, VortexDrive Matrix, PhotonLink Ten, TriCircuit Array, PentaSync Seven, UltraWave Eight, QuantumVertex Nine, HyperHelix X, DigiSpiral Z, PentaQuark Eleven, TetraCube Twelve, GigaPhase Thirteen, EchoNeuron Fourteen, FusionPulse V15, MetaQuark Sixteen, InfiniCircuit Seventeen, TeraPulse Eighteen, ExoMatrix Nineteen, OrbiSync Twenty, QuantumHelix TwentyOne, NanoPhase TwentyTwo, TeraFractal TwentyThree, PentaHelix TwentyFour, ExoCircuit TwentyFive, HyperQuark TwentySix, GigaLink TwentySeven, FusionMatrix TwentyEight, InfiniFractal TwentyNine, MetaSync Thirty, B.R.I.C.K., Q.U.A.R.T.Z., F.L.I.N.T. 
Only add necessary punctuation such as periods, commas, and capitalization, and use only the context provided.\"\n","new_text = transcribe_with_spellcheck(system_prompt, audio_filepath=ZyntriQix_filepath)\n","print(new_text)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"RPQSOTYZVeUw","executionInfo":{"status":"ok","timestamp":1708849042978,"user_tz":-480,"elapsed":7098,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"a2ef0200-9128-4bd2-d32d-55091c688caa"},"execution_count":15,"outputs":[{"output_type":"stream","name":"stdout","text":["Have you heard of ZyntriQix? This tech giant boasts products like Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, and not to forget the latest OrbitalLink Seven and DigiFractal Matrix. Their innovation arsenal also includes the PULSE framework, RAPT system, they've developed a B.R.I.C.K. infrastructure court system, and launched the F.L.I.N.T. initiative, all highlighting their commitment to relentless innovation. ZyntriQix, in just 30 years, has soared from a startup to a tech titan, serving us tech marvels alongside a stimulating linguistic challenge. Quite an adventure, wouldn't you agree?\n"]}]},{"cell_type":"code","source":["system_prompt = \"You are a helpful assistant for the company ZyntriQix. Your first task is to list the words that are not spelled correctly according to the list provided to you and to tell me the number of misspelled words. Your next task is to insert those correct words in place of the misspelled ones. List: ZyntriQix, Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, OrbitalLink Seven, DigiFractal Matrix, PULSE, RAPT, AstroPixel Array, QuantumFlare Five, CyberPulse Six, VortexDrive Matrix, PhotonLink Ten, TriCircuit Array, PentaSync Seven, UltraWave Eight, QuantumVertex Nine, HyperHelix X, DigiSpiral Z, PentaQuark Eleven, TetraCube Twelve, GigaPhase Thirteen, EchoNeuron Fourteen, FusionPulse V15, MetaQuark Sixteen, InfiniCircuit Seventeen, TeraPulse Eighteen, ExoMatrix Nineteen, OrbiSync Twenty, QuantumHelix TwentyOne, NanoPhase TwentyTwo, TeraFractal TwentyThree, PentaHelix TwentyFour, ExoCircuit TwentyFive, HyperQuark TwentySix, GigaLink TwentySeven, FusionMatrix TwentyEight, InfiniFractal TwentyNine, MetaSync Thirty, B.R.I.C.K., Q.U.A.R.T.Z., F.L.I.N.T.\"\n","new_text = transcribe_with_spellcheck(system_prompt, audio_filepath=ZyntriQix_filepath)\n","print(new_text)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"XFaPTXDnVvdb","executionInfo":{"status":"ok","timestamp":1708849070064,"user_tz":-480,"elapsed":11266,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"019bf4fa-eac2-4e62-e7ab-88d2a01e150c"},"execution_count":16,"outputs":[{"output_type":"stream","name":"stdout","text":["The misspelled words are: ZentricX, Digi-Q+, Synapse 5, VortiCore V8, Echo Nix Array, Orbital Link 7, Digifractal Matrix, Pulse, Wrapped, brick, Flint. The number of misspelled words is 11.\n","\n","The corrected paragraph is:\n","\n","Have you heard of ZyntriQix? This tech giant boasts products like Digique Plus, CynapseFive, VortiQore V8, EchoNix Array, and not to forget the latest OrbitalLink Seven and DigiFractal Matrix. Their innovation arsenal also includes the PULSE framework, RAPT system, they've developed a B.R.I.C.K. infrastructure court system, and launched the F.L.I.N.T. initiative, all highlighting their commitment to relentless innovation. 
ZyntriQix, in just 30 years, has soared from a startup to a tech titan, serving us tech marvels alongside a stimulating linguistic challenge. Quite an adventure, wouldn't you agree?\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"F9m8pO76V0_3"},"execution_count":null,"outputs":[]}]}
\ No newline at end of file
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyNq4FxI1Nnt9wQ1OIsvfhhf"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":1,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"_HFSwgSPHhuW","executionInfo":{"status":"ok","timestamp":1707620627529,"user_tz":-480,"elapsed":14418,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"cc26d2b5-bf67-4dac-dd1f-6a7b0ae3c5d3"},"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting openai\n"," Downloading openai-1.12.0-py3-none-any.whl (226 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m226.7/226.7 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from openai) (3.7.1)\n","Requirement already satisfied: distro<2,>=1.7.0 in /usr/lib/python3/dist-packages (from openai) (1.7.0)\n","Collecting httpx<1,>=0.23.0 (from openai)\n"," Downloading httpx-0.26.0-py3-none-any.whl (75 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.9/75.9 kB\u001b[0m \u001b[31m7.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from openai) (2.6.1)\n","Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from openai) (1.3.0)\n","Requirement already satisfied: tqdm>4 in /usr/local/lib/python3.10/dist-packages (from openai) (4.66.1)\n","Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from openai) (4.9.0)\n","Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (3.6)\n","Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (1.2.0)\n","Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->openai) (2024.2.2)\n","Collecting httpcore==1.* (from httpx<1,>=0.23.0->openai)\n"," Downloading httpcore-1.0.2-py3-none-any.whl (76 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m76.9/76.9 kB\u001b[0m \u001b[31m7.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting h11<0.15,>=0.13 (from httpcore==1.*->httpx<1,>=0.23.0->openai)\n"," Downloading h11-0.14.0-py3-none-any.whl (58 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (0.6.0)\n","Requirement already satisfied: pydantic-core==2.16.2 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (2.16.2)\n","Installing collected packages: h11, httpcore, httpx, openai\n","\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n","llmx 0.0.15a0 requires cohere, which is not installed.\n","llmx 0.0.15a0 requires tiktoken, which is not installed.\u001b[0m\u001b[31m\n","\u001b[0mSuccessfully installed h11-0.14.0 httpcore-1.0.2 httpx-0.26.0 openai-1.12.0\n"]}],"source":["!pip install openai"]},{"cell_type":"code","source":["# imports\n","from openai import OpenAI # OpenAI Python library to make API calls\n","import requests # used to download images\n","import os # used to access filepaths\n","from PIL import Image # used to print and edit images\n","\n","from google.colab import userdata\n","key = userdata.get('OpenAI-Key')\n","\n","# initialize OpenAI client\n","client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", key))"],"metadata":{"id":"LWChbfHZHpqE","executionInfo":{"status":"ok","timestamp":1707620664529,"user_tz":-480,"elapsed":4986,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":3,"outputs":[]},{"cell_type":"code","source":["# set a directory to save DALL·E images to\n","image_dir_name = \"images\"\n","image_dir = os.path.join(os.curdir, image_dir_name)\n","\n","# create the directory if it doesn't yet exist\n","if not os.path.isdir(image_dir):\n"," os.mkdir(image_dir)\n","\n","# print the directory to save to\n","print(f\"{image_dir=}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"iUVZeMepID6g","executionInfo":{"status":"ok","timestamp":1707620724775,"user_tz":-480,"elapsed":3,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"3a8d7670-7fb8-4014-d992-04661fcd4838"},"execution_count":4,"outputs":[{"output_type":"stream","name":"stdout","text":["image_dir='./images'\n"]}]},{"cell_type":"code","source":["# create an image\n","\n","# set the prompt\n","prompt = \"A cyberpunk monkey hacker dreaming of a beautiful bunch of bananas, digital art\"\n","\n","# call the OpenAI API\n","generation_response = client.images.generate(\n"," model = \"dall-e-3\",\n"," prompt=prompt,\n"," n=1,\n"," size=\"1024x1024\",\n"," response_format=\"url\",\n",")\n","\n","# print response\n","print(generation_response)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"YZ2cKldyIJHA","executionInfo":{"status":"ok","timestamp":1707620790602,"user_tz":-480,"elapsed":12934,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"3d6c08ad-c50c-4784-9e0f-c1c712ca31b5"},"execution_count":6,"outputs":[{"output_type":"stream","name":"stdout","text":["ImagesResponse(created=1707620790, data=[Image(b64_json=None, revised_prompt='Visualize an Afro-futurist scene featuring a Cyberpunk monkey of South Asian descent. This tech-savvy primate is known to be a brilliant hacker in the dystopian cityscape. Engrossed in its computer interfaces, the monkey is caught in a digital reverie, dreaming about a beautiful bunch of exotic bananas glowing with a neon shimmer. As the monkey visualizes its dream, the bananas take on an alluring holographic quality, floating in the air with hypnotic allure. 
This piece is to be executed in digital art format, showcasing a harmonious blend of the elements of high technology and nature.', url='https://oaidalleapiprodscus.blob.core.windows.net/private/org-kyVYbgMiqrBxWtGyhn65YxKS/user-mClH64GZgKtTJE1PDMG60x5y/img-2vOgPGXhAsvIe91fVVxqspiK.png?st=2024-02-11T02%3A06%3A30Z&se=2024-02-11T04%3A06%3A30Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2024-02-11T02%3A48%3A08Z&ske=2024-02-12T02%3A48%3A08Z&sks=b&skv=2021-08-06&sig=6PDnrkawwpcsdyQxt1jy7GlsCHGDNQFh8jqZ3YU2K8U%3D')])\n"]}]},{"cell_type":"code","source":["# save the image\n","generated_image_name = \"generated_image.png\" # any name you like; the filetype should be .png\n","generated_image_filepath = os.path.join(image_dir, generated_image_name)\n","generated_image_url = generation_response.data[0].url # extract image URL from response\n","generated_image = requests.get(generated_image_url).content # download the image\n","\n","with open(generated_image_filepath, \"wb\") as image_file:\n"," image_file.write(generated_image) # write the image to the file"],"metadata":{"id":"Trb0VCmYIV_O","executionInfo":{"status":"ok","timestamp":1707620795969,"user_tz":-480,"elapsed":1031,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["# print the image\n","print(generated_image_filepath)\n","display(Image.open(generated_image_filepath))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000,"output_embedded_package_id":"1A5civNQ1IlgQa15_NuTRULxgEE5m_anm"},"id":"zwA6a2vcIZJU","executionInfo":{"status":"ok","timestamp":1707620808046,"user_tz":-480,"elapsed":3110,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"8f6329d8-f643-46b6-a032-f79610ee2dd9"},"execution_count":8,"outputs":[{"output_type":"display_data","data":{"text/plain":"Output hidden; open in https://colab.research.google.com to view."},"metadata":{}}]},{"cell_type":"code","source":["# create variations\n","\n","# call the OpenAI API, using `create_variation` rather than `create`\n","variation_response = client.images.create_variation(\n"," image=generated_image, # generated_image is the image generated above\n"," n=2,\n"," size=\"1024x1024\",\n"," response_format=\"url\",\n",")\n","\n","# print response\n","print(variation_response)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"aNVfp-x-IdR5","executionInfo":{"status":"ok","timestamp":1707620847302,"user_tz":-480,"elapsed":11237,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"ceb4a5b0-5929-4cc3-eee1-fe72703f57f0"},"execution_count":9,"outputs":[{"output_type":"stream","name":"stdout","text":["ImagesResponse(created=1707620846, data=[Image(b64_json=None, revised_prompt=None, url='https://oaidalleapiprodscus.blob.core.windows.net/private/org-kyVYbgMiqrBxWtGyhn65YxKS/user-mClH64GZgKtTJE1PDMG60x5y/img-C6Nxl1z3nADMKOzFNP5PJjOX.png?st=2024-02-11T02%3A07%3A26Z&se=2024-02-11T04%3A07%3A26Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2024-02-11T00%3A15%3A53Z&ske=2024-02-12T00%3A15%3A53Z&sks=b&skv=2021-08-06&sig=zODK%2BIRqp1UYplmmxU5qmK4d8EI/fCULNI/itMnhB9M%3D'), Image(b64_json=None, revised_prompt=None, 
url='https://oaidalleapiprodscus.blob.core.windows.net/private/org-kyVYbgMiqrBxWtGyhn65YxKS/user-mClH64GZgKtTJE1PDMG60x5y/img-6ONkUfz0N4o1QQPRIIJSn0Q5.png?st=2024-02-11T02%3A07%3A26Z&se=2024-02-11T04%3A07%3A26Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2024-02-11T00%3A15%3A53Z&ske=2024-02-12T00%3A15%3A53Z&sks=b&skv=2021-08-06&sig=0t1oQBsbVY/5isXdc1IwnIIy5yNOrz3sO9vtd0%2BXPy0%3D')])\n"]}]},{"cell_type":"code","source":["# save the images\n","variation_urls = [datum.url for datum in variation_response.data] # extract URLs\n","variation_images = [requests.get(url).content for url in variation_urls] # download images\n","variation_image_names = [f\"variation_image_{i}.png\" for i in range(len(variation_images))] # create names\n","variation_image_filepaths = [os.path.join(image_dir, name) for name in variation_image_names] # create filepaths\n","for image, filepath in zip(variation_images, variation_image_filepaths): # loop through the variations\n"," with open(filepath, \"wb\") as image_file: # open the file\n"," image_file.write(image) # write the image to the file"],"metadata":{"id":"_N5I6-MbIj_k","executionInfo":{"status":"ok","timestamp":1707620853326,"user_tz":-480,"elapsed":1547,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":10,"outputs":[]},{"cell_type":"code","source":["# print the original image\n","print(generated_image_filepath)\n","display(Image.open(generated_image_filepath))\n","\n","# print the new variations\n","for variation_image_filepaths in variation_image_filepaths:\n"," print(variation_image_filepaths)\n"," display(Image.open(variation_image_filepaths))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000,"output_embedded_package_id":"1H1TewHJPe2gJL-Eaj4NEORy5lQ85NRiV"},"id":"21vwB0v3ImVD","executionInfo":{"status":"ok","timestamp":1707620866250,"user_tz":-480,"elapsed":6324,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"7faf9d35-1767-47f0-9455-e348e17768da"},"execution_count":11,"outputs":[{"output_type":"display_data","data":{"text/plain":"Output hidden; open in https://colab.research.google.com to view."},"metadata":{}}]},{"cell_type":"code","source":["# create a mask\n","width = 1024\n","height = 1024\n","mask = Image.new(\"RGBA\", (width, height), (0, 0, 0, 1)) # create an opaque image mask\n","\n","# set the bottom half to be transparent\n","for x in range(width):\n"," for y in range(height // 2, height): # only loop over the bottom half of the mask\n"," # set alpha (A) to zero to turn pixel transparent\n"," alpha = 0\n"," mask.putpixel((x, y), (0, 0, 0, alpha))\n","\n","# save the mask\n","mask_name = \"bottom_half_mask.png\"\n","mask_filepath = os.path.join(image_dir, mask_name)\n","mask.save(mask_filepath)"],"metadata":{"id":"B2sRioF_IsOo","executionInfo":{"status":"ok","timestamp":1707620887886,"user_tz":-480,"elapsed":725,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":12,"outputs":[]},{"cell_type":"code","source":["# edit an image\n","\n","# call the OpenAI API\n","edit_response = client.images.edit(\n"," image=open(generated_image_filepath, \"rb\"), # from the generation section\n"," mask=open(mask_filepath, \"rb\"), # from right above\n"," prompt=prompt, # from the generation section\n"," n=1,\n"," size=\"1024x1024\",\n"," response_format=\"url\",\n",")\n","\n","# print 
response\n","print(edit_response)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"wMjNaSHrIwnn","executionInfo":{"status":"ok","timestamp":1707620915324,"user_tz":-480,"elapsed":12986,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"30355f95-6b81-4891-a38f-4e6d77047343"},"execution_count":13,"outputs":[{"output_type":"stream","name":"stdout","text":["ImagesResponse(created=1707620914, data=[Image(b64_json=None, revised_prompt=None, url='https://oaidalleapiprodscus.blob.core.windows.net/private/org-kyVYbgMiqrBxWtGyhn65YxKS/user-mClH64GZgKtTJE1PDMG60x5y/img-QAChvyL9dblsL2mlnQoSB2AF.png?st=2024-02-11T02%3A08%3A34Z&se=2024-02-11T04%3A08%3A34Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2024-02-11T00%3A55%3A39Z&ske=2024-02-12T00%3A55%3A39Z&sks=b&skv=2021-08-06&sig=%2BLPO7RbTdlbGeR%2BugQsNbfKChhLETDRXnP24OCCP/sI%3D')])\n"]}]},{"cell_type":"code","source":["# save the image\n","edited_image_name = \"edited_image.png\" # any name you like; the filetype should be .png\n","edited_image_filepath = os.path.join(image_dir, edited_image_name)\n","edited_image_url = edit_response.data[0].url # extract image URL from response\n","edited_image = requests.get(edited_image_url).content # download the image\n","\n","with open(edited_image_filepath, \"wb\") as image_file:\n"," image_file.write(edited_image) # write the image to the file\n","# print the original image\n","print(generated_image_filepath)\n","display(Image.open(generated_image_filepath))\n","\n","# print edited image\n","print(edited_image_filepath)\n","display(Image.open(edited_image_filepath))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000,"output_embedded_package_id":"1ALB0HUAW5L0MesfvRQ-txqrJpFHZ_8YJ"},"id":"HRFp8qGKIy8D","executionInfo":{"status":"ok","timestamp":1707620921356,"user_tz":-480,"elapsed":6034,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"0b058d3c-e55a-4cda-b928-2f4a4ec6b27c"},"execution_count":14,"outputs":[{"output_type":"display_data","data":{"text/plain":"Output hidden; open in https://colab.research.google.com to view."},"metadata":{}}]}]}
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyNq4FxI1Nnt9wQ1OIsvfhhf"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":1,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"_HFSwgSPHhuW","executionInfo":{"status":"ok","timestamp":1707620627529,"user_tz":-480,"elapsed":14418,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"cc26d2b5-bf67-4dac-dd1f-6a7b0ae3c5d3"},"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting openai\n"," Downloading openai-1.12.0-py3-none-any.whl (226 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m226.7/226.7 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from openai) (3.7.1)\n","Requirement already satisfied: distro<2,>=1.7.0 in /usr/lib/python3/dist-packages (from openai) (1.7.0)\n","Collecting httpx<1,>=0.23.0 (from openai)\n"," Downloading httpx-0.26.0-py3-none-any.whl (75 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.9/75.9 kB\u001b[0m \u001b[31m7.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from openai) (2.6.1)\n","Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from openai) (1.3.0)\n","Requirement already satisfied: tqdm>4 in /usr/local/lib/python3.10/dist-packages (from openai) (4.66.1)\n","Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from openai) (4.9.0)\n","Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (3.6)\n","Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (1.2.0)\n","Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->openai) (2024.2.2)\n","Collecting httpcore==1.* (from httpx<1,>=0.23.0->openai)\n"," Downloading httpcore-1.0.2-py3-none-any.whl (76 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m76.9/76.9 kB\u001b[0m \u001b[31m7.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting h11<0.15,>=0.13 (from httpcore==1.*->httpx<1,>=0.23.0->openai)\n"," Downloading h11-0.14.0-py3-none-any.whl (58 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (0.6.0)\n","Requirement already satisfied: pydantic-core==2.16.2 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (2.16.2)\n","Installing collected packages: h11, httpcore, httpx, openai\n","\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n","llmx 0.0.15a0 requires cohere, which is not installed.\n","llmx 0.0.15a0 requires tiktoken, which is not installed.\u001b[0m\u001b[31m\n","\u001b[0mSuccessfully installed h11-0.14.0 httpcore-1.0.2 httpx-0.26.0 openai-1.12.0\n"]}],"source":["!pip install openai"]},{"cell_type":"code","source":["# imports\n","from openai import OpenAI # OpenAI Python library to make API calls\n","import requests # used to download images\n","import os # used to access filepaths\n","from PIL import Image # used to print and edit images\n","\n","from google.colab import userdata\n","key = userdata.get('OpenAI-Key')\n","\n","# initialize OpenAI client\n","client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", key))"],"metadata":{"id":"LWChbfHZHpqE","executionInfo":{"status":"ok","timestamp":1707620664529,"user_tz":-480,"elapsed":4986,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":3,"outputs":[]},{"cell_type":"code","source":["# set a directory to save DALL·E images to\n","image_dir_name = \"images\"\n","image_dir = os.path.join(os.curdir, image_dir_name)\n","\n","# create the directory if it doesn't yet exist\n","if not os.path.isdir(image_dir):\n"," os.mkdir(image_dir)\n","\n","# print the directory to save to\n","print(f\"{image_dir=}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"iUVZeMepID6g","executionInfo":{"status":"ok","timestamp":1707620724775,"user_tz":-480,"elapsed":3,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"3a8d7670-7fb8-4014-d992-04661fcd4838"},"execution_count":4,"outputs":[{"output_type":"stream","name":"stdout","text":["image_dir='./images'\n"]}]},{"cell_type":"code","source":["# create an image\n","\n","# set the prompt\n","prompt = \"A cyberpunk monkey hacker dreaming of a beautiful bunch of bananas, digital art\"\n","\n","# call the OpenAI API\n","generation_response = client.images.generate(\n"," model = \"dall-e-3\",\n"," prompt=prompt,\n"," n=1,\n"," size=\"1024x1024\",\n"," response_format=\"url\",\n",")\n","\n","# print response\n","print(generation_response)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"YZ2cKldyIJHA","executionInfo":{"status":"ok","timestamp":1707620790602,"user_tz":-480,"elapsed":12934,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"3d6c08ad-c50c-4784-9e0f-c1c712ca31b5"},"execution_count":6,"outputs":[{"output_type":"stream","name":"stdout","text":["ImagesResponse(created=1707620790, data=[Image(b64_json=None, revised_prompt='Visualize an Afro-futurist scene featuring a Cyberpunk monkey of South Asian descent. This tech-savvy primate is known to be a brilliant hacker in the dystopian cityscape. Engrossed in its computer interfaces, the monkey is caught in a digital reverie, dreaming about a beautiful bunch of exotic bananas glowing with a neon shimmer. As the monkey visualizes its dream, the bananas take on an alluring holographic quality, floating in the air with hypnotic allure. 
This piece is to be executed in digital art format, showcasing a harmonious blend of the elements of high technology and nature.', url='https://oaidalleapiprodscus.blob.core.windows.net/private/org-kyVYbgMiqrBxWtGyhn65YxKS/user-mClH64GZgKtTJE1PDMG60x5y/img-2vOgPGXhAsvIe91fVVxqspiK.png?st=2024-02-11T02%3A06%3A30Z&se=2024-02-11T04%3A06%3A30Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2024-02-11T02%3A48%3A08Z&ske=2024-02-12T02%3A48%3A08Z&sks=b&skv=2021-08-06&sig=6PDnrkawwpcsdyQxt1jy7GlsCHGDNQFh8jqZ3YU2K8U%3D')])\n"]}]},{"cell_type":"code","source":["# save the image\n","generated_image_name = \"generated_image.png\" # any name you like; the filetype should be .png\n","generated_image_filepath = os.path.join(image_dir, generated_image_name)\n","generated_image_url = generation_response.data[0].url # extract image URL from response\n","generated_image = requests.get(generated_image_url).content # download the image\n","\n","with open(generated_image_filepath, \"wb\") as image_file:\n"," image_file.write(generated_image) # write the image to the file"],"metadata":{"id":"Trb0VCmYIV_O","executionInfo":{"status":"ok","timestamp":1707620795969,"user_tz":-480,"elapsed":1031,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["# print the image\n","print(generated_image_filepath)\n","display(Image.open(generated_image_filepath))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000,"output_embedded_package_id":"1A5civNQ1IlgQa15_NuTRULxgEE5m_anm"},"id":"zwA6a2vcIZJU","executionInfo":{"status":"ok","timestamp":1707620808046,"user_tz":-480,"elapsed":3110,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"8f6329d8-f643-46b6-a032-f79610ee2dd9"},"execution_count":8,"outputs":[{"output_type":"display_data","data":{"text/plain":"Output hidden; open in https://colab.research.google.com to view."},"metadata":{}}]},{"cell_type":"code","source":["# create variations\n","\n","# call the OpenAI API, using `create_variation` rather than `create`\n","variation_response = client.images.create_variation(\n"," image=generated_image, # generated_image is the image generated above\n"," n=2,\n"," size=\"1024x1024\",\n"," response_format=\"url\",\n",")\n","\n","# print response\n","print(variation_response)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"aNVfp-x-IdR5","executionInfo":{"status":"ok","timestamp":1707620847302,"user_tz":-480,"elapsed":11237,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"ceb4a5b0-5929-4cc3-eee1-fe72703f57f0"},"execution_count":9,"outputs":[{"output_type":"stream","name":"stdout","text":["ImagesResponse(created=1707620846, data=[Image(b64_json=None, revised_prompt=None, url='https://oaidalleapiprodscus.blob.core.windows.net/private/org-kyVYbgMiqrBxWtGyhn65YxKS/user-mClH64GZgKtTJE1PDMG60x5y/img-C6Nxl1z3nADMKOzFNP5PJjOX.png?st=2024-02-11T02%3A07%3A26Z&se=2024-02-11T04%3A07%3A26Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2024-02-11T00%3A15%3A53Z&ske=2024-02-12T00%3A15%3A53Z&sks=b&skv=2021-08-06&sig=zODK%2BIRqp1UYplmmxU5qmK4d8EI/fCULNI/itMnhB9M%3D'), Image(b64_json=None, revised_prompt=None, 
url='https://oaidalleapiprodscus.blob.core.windows.net/private/org-kyVYbgMiqrBxWtGyhn65YxKS/user-mClH64GZgKtTJE1PDMG60x5y/img-6ONkUfz0N4o1QQPRIIJSn0Q5.png?st=2024-02-11T02%3A07%3A26Z&se=2024-02-11T04%3A07%3A26Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2024-02-11T00%3A15%3A53Z&ske=2024-02-12T00%3A15%3A53Z&sks=b&skv=2021-08-06&sig=0t1oQBsbVY/5isXdc1IwnIIy5yNOrz3sO9vtd0%2BXPy0%3D')])\n"]}]},{"cell_type":"code","source":["# save the images\n","variation_urls = [datum.url for datum in variation_response.data] # extract URLs\n","variation_images = [requests.get(url).content for url in variation_urls] # download images\n","variation_image_names = [f\"variation_image_{i}.png\" for i in range(len(variation_images))] # create names\n","variation_image_filepaths = [os.path.join(image_dir, name) for name in variation_image_names] # create filepaths\n","for image, filepath in zip(variation_images, variation_image_filepaths): # loop through the variations\n"," with open(filepath, \"wb\") as image_file: # open the file\n"," image_file.write(image) # write the image to the file"],"metadata":{"id":"_N5I6-MbIj_k","executionInfo":{"status":"ok","timestamp":1707620853326,"user_tz":-480,"elapsed":1547,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":10,"outputs":[]},{"cell_type":"code","source":["# print the original image\n","print(generated_image_filepath)\n","display(Image.open(generated_image_filepath))\n","\n","# print the new variations\n","for variation_image_filepaths in variation_image_filepaths:\n"," print(variation_image_filepaths)\n"," display(Image.open(variation_image_filepaths))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000,"output_embedded_package_id":"1H1TewHJPe2gJL-Eaj4NEORy5lQ85NRiV"},"id":"21vwB0v3ImVD","executionInfo":{"status":"ok","timestamp":1707620866250,"user_tz":-480,"elapsed":6324,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"7faf9d35-1767-47f0-9455-e348e17768da"},"execution_count":11,"outputs":[{"output_type":"display_data","data":{"text/plain":"Output hidden; open in https://colab.research.google.com to view."},"metadata":{}}]},{"cell_type":"code","source":["# create a mask\n","width = 1024\n","height = 1024\n","mask = Image.new(\"RGBA\", (width, height), (0, 0, 0, 1)) # create an opaque image mask\n","\n","# set the bottom half to be transparent\n","for x in range(width):\n"," for y in range(height // 2, height): # only loop over the bottom half of the mask\n"," # set alpha (A) to zero to turn pixel transparent\n"," alpha = 0\n"," mask.putpixel((x, y), (0, 0, 0, alpha))\n","\n","# save the mask\n","mask_name = \"bottom_half_mask.png\"\n","mask_filepath = os.path.join(image_dir, mask_name)\n","mask.save(mask_filepath)"],"metadata":{"id":"B2sRioF_IsOo","executionInfo":{"status":"ok","timestamp":1707620887886,"user_tz":-480,"elapsed":725,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":12,"outputs":[]},{"cell_type":"code","source":["# edit an image\n","\n","# call the OpenAI API\n","edit_response = client.images.edit(\n"," image=open(generated_image_filepath, \"rb\"), # from the generation section\n"," mask=open(mask_filepath, \"rb\"), # from right above\n"," prompt=prompt, # from the generation section\n"," n=1,\n"," size=\"1024x1024\",\n"," response_format=\"url\",\n",")\n","\n","# print 
response\n","print(edit_response)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"wMjNaSHrIwnn","executionInfo":{"status":"ok","timestamp":1707620915324,"user_tz":-480,"elapsed":12986,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"30355f95-6b81-4891-a38f-4e6d77047343"},"execution_count":13,"outputs":[{"output_type":"stream","name":"stdout","text":["ImagesResponse(created=1707620914, data=[Image(b64_json=None, revised_prompt=None, url='https://oaidalleapiprodscus.blob.core.windows.net/private/org-kyVYbgMiqrBxWtGyhn65YxKS/user-mClH64GZgKtTJE1PDMG60x5y/img-QAChvyL9dblsL2mlnQoSB2AF.png?st=2024-02-11T02%3A08%3A34Z&se=2024-02-11T04%3A08%3A34Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2024-02-11T00%3A55%3A39Z&ske=2024-02-12T00%3A55%3A39Z&sks=b&skv=2021-08-06&sig=%2BLPO7RbTdlbGeR%2BugQsNbfKChhLETDRXnP24OCCP/sI%3D')])\n"]}]},{"cell_type":"code","source":["# save the image\n","edited_image_name = \"edited_image.png\" # any name you like; the filetype should be .png\n","edited_image_filepath = os.path.join(image_dir, edited_image_name)\n","edited_image_url = edit_response.data[0].url # extract image URL from response\n","edited_image = requests.get(edited_image_url).content # download the image\n","\n","with open(edited_image_filepath, \"wb\") as image_file:\n"," image_file.write(edited_image) # write the image to the file\n","# print the original image\n","print(generated_image_filepath)\n","display(Image.open(generated_image_filepath))\n","\n","# print edited image\n","print(edited_image_filepath)\n","display(Image.open(edited_image_filepath))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000,"output_embedded_package_id":"1ALB0HUAW5L0MesfvRQ-txqrJpFHZ_8YJ"},"id":"HRFp8qGKIy8D","executionInfo":{"status":"ok","timestamp":1707620921356,"user_tz":-480,"elapsed":6034,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"0b058d3c-e55a-4cda-b928-2f4a4ec6b27c"},"execution_count":14,"outputs":[{"output_type":"display_data","data":{"text/plain":"Output hidden; open in https://colab.research.google.com to view."},"metadata":{}}]}]}
\ No newline at end of file
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyPIJzTZmtW8FvyfCtddaeI7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"GaYCNssju8c8"},"outputs":[],"source":["!pip install timm\n","!pip install transformers"]},{"cell_type":"code","source":["import os\n","import cv2\n","import gc\n","import numpy as np\n","import pandas as pd\n","import itertools\n","from tqdm.autonotebook import tqdm\n","import albumentations as A\n","import matplotlib.pyplot as plt\n","\n","import torch\n","from torch import nn\n","import torch.nn.functional as F\n","import timm\n","from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer"],"metadata":{"id":"XXebCRIUvpQk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["!pip install kaggle --upgrade\n","# change your name and key\n","os.environ['KAGGLE_USERNAME'] = \"XXXXX\"\n","os.environ['KAGGLE_KEY'] = \"XXXXXXXXXXXXXX\"\n","\n","### For Flickr 8k\n","!kaggle datasets download -d adityajn105/flickr8k\n","!unzip flickr8k.zip\n","dataset = \"8k\"\n","\n","\n","### For Flickr 30k\n","# !kaggle datasets download -d hsankesara/flickr-image-dataset\n","# !unzip flickr-image-dataset.zip\n","# dataset = \"30k\""],"metadata":{"id":"Y1XOe4mfvq2p"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["if dataset == \"8k\":\n"," df = pd.read_csv(\"captions.txt\")\n"," df['id'] = [id_ for id_ in range(df.shape[0] // 5) for _ in range(5)]\n"," df.to_csv(\"captions.csv\", index=False)\n"," df = pd.read_csv(\"captions.csv\")\n"," image_path = \"/content/Images\"\n"," captions_path = \"/content\"\n","elif dataset == \"30k\":\n"," df = pd.read_csv(\"/content/flickr30k_images/results.csv\", delimiter=\"|\")\n"," df.columns = ['image', 'caption_number', 'caption']\n"," df['caption'] = df['caption'].str.lstrip()\n"," df['caption_number'] = df['caption_number'].str.lstrip()\n"," df.loc[19999, 'caption_number'] = \"4\"\n"," df.loc[19999, 'caption'] = \"A dog runs across the grass .\"\n"," ids = [id_ for id_ in range(len(df) // 5) for _ in range(5)]\n"," df['id'] = ids\n"," df.to_csv(\"captions.csv\", index=False)\n"," image_path = \"/content/flickr30k_images/flickr30k_images\"\n"," captions_path = \"/content\"\n","\n","df.head()"],"metadata":{"id":"LvS448T4vwHF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class CFG:\n"," debug = False\n"," image_path = image_path\n"," captions_path = captions_path\n"," batch_size = 32\n"," num_workers = 2\n"," head_lr = 1e-3\n"," image_encoder_lr = 1e-4\n"," text_encoder_lr = 1e-5\n"," weight_decay = 1e-3\n"," patience = 1\n"," factor = 0.8\n"," epochs = 4\n"," device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","\n"," model_name = 'resnet50'\n"," image_embedding = 2048\n"," text_encoder_model = \"distilbert-base-uncased\"\n"," text_embedding = 768\n"," text_tokenizer = \"distilbert-base-uncased\"\n"," max_length = 200\n","\n"," pretrained = True # for both image encoder and text encoder\n"," trainable = True # for both image encoder and text encoder\n"," temperature = 1.0\n","\n"," # image size\n"," size = 224\n","\n"," # for projection head; used for both image and text encoders\n"," num_projection_layers = 1\n"," projection_dim = 256\n"," dropout = 0.1"],"metadata":{"id":"LIOAspaYwF_N"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class AvgMeter:\n"," def 
__init__(self, name=\"Metric\"):\n"," self.name = name\n"," self.reset()\n","\n"," def reset(self):\n"," self.avg, self.sum, self.count = [0] * 3\n","\n"," def update(self, val, count=1):\n"," self.count += count\n"," self.sum += val * count\n"," self.avg = self.sum / self.count\n","\n"," def __repr__(self):\n"," text = f\"{self.name}: {self.avg:.4f}\"\n"," return text\n","\n","def get_lr(optimizer):\n"," for param_group in optimizer.param_groups:\n"," return param_group[\"lr\"]"],"metadata":{"id":"HM_QqFlqwH4a"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class CLIPDataset(torch.utils.data.Dataset):\n"," def __init__(self, image_filenames, captions, tokenizer, transforms):\n"," \"\"\"\n"," image_filenames and cpations must have the same length; so, if there are\n"," multiple captions for each image, the image_filenames must have repetitive\n"," file names\n"," \"\"\"\n","\n"," self.image_filenames = image_filenames\n"," self.captions = list(captions)\n"," self.encoded_captions = tokenizer(\n"," list(captions), padding=True, truncation=True, max_length=CFG.max_length\n"," )\n"," self.transforms = transforms\n","\n"," def __getitem__(self, idx):\n"," item = {\n"," key: torch.tensor(values[idx])\n"," for key, values in self.encoded_captions.items()\n"," }\n","\n"," image = cv2.imread(f\"{CFG.image_path}/{self.image_filenames[idx]}\")\n"," image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n"," image = self.transforms(image=image)['image']\n"," item['image'] = torch.tensor(image).permute(2, 0, 1).float()\n"," item['caption'] = self.captions[idx]\n","\n"," return item\n","\n","\n"," def __len__(self):\n"," return len(self.captions)\n","\n","\n","\n","def get_transforms(mode=\"train\"):\n"," if mode == \"train\":\n"," return A.Compose(\n"," [\n"," A.Resize(CFG.size, CFG.size, always_apply=True),\n"," A.Normalize(max_pixel_value=255.0, always_apply=True),\n"," ]\n"," )\n"," else:\n"," return A.Compose(\n"," [\n"," A.Resize(CFG.size, CFG.size, always_apply=True),\n"," A.Normalize(max_pixel_value=255.0, always_apply=True),\n"," ]\n"," )"],"metadata":{"id":"w2JcEE6gwJu6"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class ImageEncoder(nn.Module):\n"," \"\"\"\n"," Encode images to a fixed size vector\n"," \"\"\"\n","\n"," def __init__(\n"," self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable\n"," ):\n"," super().__init__()\n"," self.model = timm.create_model(\n"," model_name, pretrained, num_classes=0, global_pool=\"avg\"\n"," )\n"," for p in self.model.parameters():\n"," p.requires_grad = trainable\n","\n"," def forward(self, x):\n"," return self.model(x)"],"metadata":{"id":"mrZ4jsXAwONW"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class TextEncoder(nn.Module):\n"," def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):\n"," super().__init__()\n"," if pretrained:\n"," self.model = DistilBertModel.from_pretrained(model_name)\n"," else:\n"," self.model = DistilBertModel(config=DistilBertConfig())\n","\n"," for p in self.model.parameters():\n"," p.requires_grad = trainable\n","\n"," # we are using the CLS token hidden representation as the sentence's embedding\n"," self.target_token_idx = 0\n","\n"," def forward(self, input_ids, attention_mask):\n"," output = self.model(input_ids=input_ids, attention_mask=attention_mask)\n"," last_hidden_state = output.last_hidden_state\n"," return last_hidden_state[:, self.target_token_idx, 
:]"],"metadata":{"id":"ju9a8O0JwQBx"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class ProjectionHead(nn.Module):\n"," def __init__(\n"," self,\n"," embedding_dim,\n"," projection_dim=CFG.projection_dim,\n"," dropout=CFG.dropout\n"," ):\n"," super().__init__()\n"," self.projection = nn.Linear(embedding_dim, projection_dim)\n"," self.gelu = nn.GELU()\n"," self.fc = nn.Linear(projection_dim, projection_dim)\n"," self.dropout = nn.Dropout(dropout)\n"," self.layer_norm = nn.LayerNorm(projection_dim)\n","\n"," def forward(self, x):\n"," projected = self.projection(x)\n"," x = self.gelu(projected)\n"," x = self.fc(x)\n"," x = self.dropout(x)\n"," x = x + projected\n"," x = self.layer_norm(x)\n"," return x"],"metadata":{"id":"UXREwtR9wRpI"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class CLIPModel(nn.Module):\n"," def __init__(\n"," self,\n"," temperature=CFG.temperature,\n"," image_embedding=CFG.image_embedding,\n"," text_embedding=CFG.text_embedding,\n"," ):\n"," super().__init__()\n"," self.image_encoder = ImageEncoder()\n"," self.text_encoder = TextEncoder()\n"," self.image_projection = ProjectionHead(embedding_dim=image_embedding)\n"," self.text_projection = ProjectionHead(embedding_dim=text_embedding)\n"," self.temperature = temperature\n","\n"," def forward(self, batch):\n"," # Getting Image and Text Features\n"," image_features = self.image_encoder(batch[\"image\"])\n"," text_features = self.text_encoder(\n"," input_ids=batch[\"input_ids\"], attention_mask=batch[\"attention_mask\"]\n"," )\n"," # Getting Image and Text Embeddings (with same dimension)\n"," image_embeddings = self.image_projection(image_features)\n"," text_embeddings = self.text_projection(text_features)\n","\n"," # Calculating the Loss\n"," logits = (text_embeddings @ image_embeddings.T) / self.temperature\n"," images_similarity = image_embeddings @ image_embeddings.T\n"," texts_similarity = text_embeddings @ text_embeddings.T\n"," targets = F.softmax(\n"," (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1\n"," )\n"," texts_loss = cross_entropy(logits, targets, reduction='none')\n"," images_loss = cross_entropy(logits.T, targets.T, reduction='none')\n"," loss = (images_loss + texts_loss) / 2.0 # shape: (batch_size)\n"," return loss.mean()\n","\n","\n","def cross_entropy(preds, targets, reduction='none'):\n"," log_softmax = nn.LogSoftmax(dim=-1)\n"," loss = (-targets * log_softmax(preds)).sum(1)\n"," if reduction == \"none\":\n"," return loss\n"," elif reduction == \"mean\":\n"," return loss.mean()"],"metadata":{"id":"1_fhZIC4wTQA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# A simple Example\n","\n","batch_size = 4\n","dim = 256\n","embeddings = torch.randn(batch_size, dim)\n","out = embeddings @ embeddings.T\n","print(F.softmax(out, dim=-1))"],"metadata":{"id":"ajzH97wGwVwD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def make_train_valid_dfs():\n"," dataframe = pd.read_csv(f\"{CFG.captions_path}/captions.csv\")\n"," max_id = dataframe[\"id\"].max() + 1 if not CFG.debug else 100\n"," image_ids = np.arange(0, max_id)\n"," np.random.seed(42)\n"," valid_ids = np.random.choice(\n"," image_ids, size=int(0.2 * len(image_ids)), replace=False\n"," )\n"," train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]\n"," train_dataframe = dataframe[dataframe[\"id\"].isin(train_ids)].reset_index(drop=True)\n"," valid_dataframe = dataframe[dataframe[\"id\"].isin(valid_ids)].reset_index(drop=True)\n"," return 
train_dataframe, valid_dataframe\n","\n","\n","def build_loaders(dataframe, tokenizer, mode):\n"," transforms = get_transforms(mode=mode)\n"," dataset = CLIPDataset(\n"," dataframe[\"image\"].values,\n"," dataframe[\"caption\"].values,\n"," tokenizer=tokenizer,\n"," transforms=transforms,\n"," )\n"," dataloader = torch.utils.data.DataLoader(\n"," dataset,\n"," batch_size=CFG.batch_size,\n"," num_workers=CFG.num_workers,\n"," shuffle=True if mode == \"train\" else False,\n"," )\n"," return dataloader"],"metadata":{"id":"mYNL9TCewXMG"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def train_epoch(model, train_loader, optimizer, lr_scheduler, step):\n"," loss_meter = AvgMeter()\n"," tqdm_object = tqdm(train_loader, total=len(train_loader))\n"," for batch in tqdm_object:\n"," batch = {k: v.to(CFG.device) for k, v in batch.items() if k != \"caption\"}\n"," loss = model(batch)\n"," optimizer.zero_grad()\n"," loss.backward()\n"," optimizer.step()\n"," if step == \"batch\":\n"," lr_scheduler.step()\n","\n"," count = batch[\"image\"].size(0)\n"," loss_meter.update(loss.item(), count)\n","\n"," tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))\n"," return loss_meter\n","\n","\n","def valid_epoch(model, valid_loader):\n"," loss_meter = AvgMeter()\n","\n"," tqdm_object = tqdm(valid_loader, total=len(valid_loader))\n"," for batch in tqdm_object:\n"," batch = {k: v.to(CFG.device) for k, v in batch.items() if k != \"caption\"}\n"," loss = model(batch)\n","\n"," count = batch[\"image\"].size(0)\n"," loss_meter.update(loss.item(), count)\n","\n"," tqdm_object.set_postfix(valid_loss=loss_meter.avg)\n"," return loss_meter\n","\n","\n","def main():\n"," train_df, valid_df = make_train_valid_dfs()\n"," tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)\n"," train_loader = build_loaders(train_df, tokenizer, mode=\"train\")\n"," valid_loader = build_loaders(valid_df, tokenizer, mode=\"valid\")\n","\n","\n"," model = CLIPModel().to(CFG.device)\n"," params = [\n"," {\"params\": model.image_encoder.parameters(), \"lr\": CFG.image_encoder_lr},\n"," {\"params\": model.text_encoder.parameters(), \"lr\": CFG.text_encoder_lr},\n"," {\"params\": itertools.chain(\n"," model.image_projection.parameters(), model.text_projection.parameters()\n"," ), \"lr\": CFG.head_lr, \"weight_decay\": CFG.weight_decay}\n"," ]\n"," optimizer = torch.optim.AdamW(params, weight_decay=0.)\n"," lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n"," optimizer, mode=\"min\", patience=CFG.patience, factor=CFG.factor\n"," )\n"," step = \"epoch\"\n","\n"," best_loss = float('inf')\n"," for epoch in range(CFG.epochs):\n"," print(f\"Epoch: {epoch + 1}\")\n"," model.train()\n"," train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)\n"," model.eval()\n"," with torch.no_grad():\n"," valid_loss = valid_epoch(model, valid_loader)\n","\n"," if valid_loss.avg < best_loss:\n"," best_loss = valid_loss.avg\n"," torch.save(model.state_dict(), \"best.pt\")\n"," print(\"Saved Best Model!\")\n","\n"," lr_scheduler.step(valid_loss.avg)"],"metadata":{"id":"ftucX7sqwYy8"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def get_image_embeddings(valid_df, model_path):\n"," tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)\n"," valid_loader = build_loaders(valid_df, tokenizer, mode=\"valid\")\n","\n"," model = CLIPModel().to(CFG.device)\n"," model.load_state_dict(torch.load(model_path, map_location=CFG.device))\n"," model.eval()\n","\n"," 
valid_image_embeddings = []\n"," with torch.no_grad():\n"," for batch in tqdm(valid_loader):\n"," image_features = model.image_encoder(batch[\"image\"].to(CFG.device))\n"," image_embeddings = model.image_projection(image_features)\n"," valid_image_embeddings.append(image_embeddings)\n"," return model, torch.cat(valid_image_embeddings)"],"metadata":{"id":"RFD9cjhywakf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, valid_df = make_train_valid_dfs()\n","model, image_embeddings = get_image_embeddings(valid_df, \"best.pt\")"],"metadata":{"id":"xwI4wdTfwcQ0"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def find_matches(model, image_embeddings, query, image_filenames, n=9):\n"," tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)\n"," encoded_query = tokenizer([query])\n"," batch = {\n"," key: torch.tensor(values).to(CFG.device)\n"," for key, values in encoded_query.items()\n"," }\n"," with torch.no_grad():\n"," text_features = model.text_encoder(\n"," input_ids=batch[\"input_ids\"], attention_mask=batch[\"attention_mask\"]\n"," )\n"," text_embeddings = model.text_projection(text_features)\n","\n"," image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)\n"," text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)\n"," dot_similarity = text_embeddings_n @ image_embeddings_n.T\n","\n"," values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)\n"," matches = [image_filenames[idx] for idx in indices[::5]]\n","\n"," _, axes = plt.subplots(3, 3, figsize=(10, 10))\n"," for match, ax in zip(matches, axes.flatten()):\n"," image = cv2.imread(f\"{CFG.image_path}/{match}\")\n"," image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n"," ax.imshow(image)\n"," ax.axis(\"off\")\n","\n"," plt.show()"],"metadata":{"id":"G4cfYebRwfAi"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["find_matches(model,\n"," image_embeddings,\n"," query=\"dogs on the grass\",\n"," image_filenames=valid_df['image'].values,\n"," n=9)"],"metadata":{"id":"5ikO765hwgce"},"execution_count":null,"outputs":[]}]}
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyPIJzTZmtW8FvyfCtddaeI7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"GaYCNssju8c8"},"outputs":[],"source":["!pip install timm\n","!pip install transformers"]},{"cell_type":"code","source":["import os\n","import cv2\n","import gc\n","import numpy as np\n","import pandas as pd\n","import itertools\n","from tqdm.autonotebook import tqdm\n","import albumentations as A\n","import matplotlib.pyplot as plt\n","\n","import torch\n","from torch import nn\n","import torch.nn.functional as F\n","import timm\n","from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer"],"metadata":{"id":"XXebCRIUvpQk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["!pip install kaggle --upgrade\n","# change your name and key\n","os.environ['KAGGLE_USERNAME'] = \"XXXXX\"\n","os.environ['KAGGLE_KEY'] = \"XXXXXXXXXXXXXX\"\n","\n","### For Flickr 8k\n","!kaggle datasets download -d adityajn105/flickr8k\n","!unzip flickr8k.zip\n","dataset = \"8k\"\n","\n","\n","### For Flickr 30k\n","# !kaggle datasets download -d hsankesara/flickr-image-dataset\n","# !unzip flickr-image-dataset.zip\n","# dataset = \"30k\""],"metadata":{"id":"Y1XOe4mfvq2p"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["if dataset == \"8k\":\n"," df = pd.read_csv(\"captions.txt\")\n"," df['id'] = [id_ for id_ in range(df.shape[0] // 5) for _ in range(5)]\n"," df.to_csv(\"captions.csv\", index=False)\n"," df = pd.read_csv(\"captions.csv\")\n"," image_path = \"/content/Images\"\n"," captions_path = \"/content\"\n","elif dataset == \"30k\":\n"," df = pd.read_csv(\"/content/flickr30k_images/results.csv\", delimiter=\"|\")\n"," df.columns = ['image', 'caption_number', 'caption']\n"," df['caption'] = df['caption'].str.lstrip()\n"," df['caption_number'] = df['caption_number'].str.lstrip()\n"," df.loc[19999, 'caption_number'] = \"4\"\n"," df.loc[19999, 'caption'] = \"A dog runs across the grass .\"\n"," ids = [id_ for id_ in range(len(df) // 5) for _ in range(5)]\n"," df['id'] = ids\n"," df.to_csv(\"captions.csv\", index=False)\n"," image_path = \"/content/flickr30k_images/flickr30k_images\"\n"," captions_path = \"/content\"\n","\n","df.head()"],"metadata":{"id":"LvS448T4vwHF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class CFG:\n"," debug = False\n"," image_path = image_path\n"," captions_path = captions_path\n"," batch_size = 32\n"," num_workers = 2\n"," head_lr = 1e-3\n"," image_encoder_lr = 1e-4\n"," text_encoder_lr = 1e-5\n"," weight_decay = 1e-3\n"," patience = 1\n"," factor = 0.8\n"," epochs = 4\n"," device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","\n"," model_name = 'resnet50'\n"," image_embedding = 2048\n"," text_encoder_model = \"distilbert-base-uncased\"\n"," text_embedding = 768\n"," text_tokenizer = \"distilbert-base-uncased\"\n"," max_length = 200\n","\n"," pretrained = True # for both image encoder and text encoder\n"," trainable = True # for both image encoder and text encoder\n"," temperature = 1.0\n","\n"," # image size\n"," size = 224\n","\n"," # for projection head; used for both image and text encoders\n"," num_projection_layers = 1\n"," projection_dim = 256\n"," dropout = 0.1"],"metadata":{"id":"LIOAspaYwF_N"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class AvgMeter:\n"," def 
__init__(self, name=\"Metric\"):\n"," self.name = name\n"," self.reset()\n","\n"," def reset(self):\n"," self.avg, self.sum, self.count = [0] * 3\n","\n"," def update(self, val, count=1):\n"," self.count += count\n"," self.sum += val * count\n"," self.avg = self.sum / self.count\n","\n"," def __repr__(self):\n"," text = f\"{self.name}: {self.avg:.4f}\"\n"," return text\n","\n","def get_lr(optimizer):\n"," for param_group in optimizer.param_groups:\n"," return param_group[\"lr\"]"],"metadata":{"id":"HM_QqFlqwH4a"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class CLIPDataset(torch.utils.data.Dataset):\n"," def __init__(self, image_filenames, captions, tokenizer, transforms):\n"," \"\"\"\n"," image_filenames and cpations must have the same length; so, if there are\n"," multiple captions for each image, the image_filenames must have repetitive\n"," file names\n"," \"\"\"\n","\n"," self.image_filenames = image_filenames\n"," self.captions = list(captions)\n"," self.encoded_captions = tokenizer(\n"," list(captions), padding=True, truncation=True, max_length=CFG.max_length\n"," )\n"," self.transforms = transforms\n","\n"," def __getitem__(self, idx):\n"," item = {\n"," key: torch.tensor(values[idx])\n"," for key, values in self.encoded_captions.items()\n"," }\n","\n"," image = cv2.imread(f\"{CFG.image_path}/{self.image_filenames[idx]}\")\n"," image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n"," image = self.transforms(image=image)['image']\n"," item['image'] = torch.tensor(image).permute(2, 0, 1).float()\n"," item['caption'] = self.captions[idx]\n","\n"," return item\n","\n","\n"," def __len__(self):\n"," return len(self.captions)\n","\n","\n","\n","def get_transforms(mode=\"train\"):\n"," if mode == \"train\":\n"," return A.Compose(\n"," [\n"," A.Resize(CFG.size, CFG.size, always_apply=True),\n"," A.Normalize(max_pixel_value=255.0, always_apply=True),\n"," ]\n"," )\n"," else:\n"," return A.Compose(\n"," [\n"," A.Resize(CFG.size, CFG.size, always_apply=True),\n"," A.Normalize(max_pixel_value=255.0, always_apply=True),\n"," ]\n"," )"],"metadata":{"id":"w2JcEE6gwJu6"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class ImageEncoder(nn.Module):\n"," \"\"\"\n"," Encode images to a fixed size vector\n"," \"\"\"\n","\n"," def __init__(\n"," self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable\n"," ):\n"," super().__init__()\n"," self.model = timm.create_model(\n"," model_name, pretrained, num_classes=0, global_pool=\"avg\"\n"," )\n"," for p in self.model.parameters():\n"," p.requires_grad = trainable\n","\n"," def forward(self, x):\n"," return self.model(x)"],"metadata":{"id":"mrZ4jsXAwONW"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class TextEncoder(nn.Module):\n"," def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):\n"," super().__init__()\n"," if pretrained:\n"," self.model = DistilBertModel.from_pretrained(model_name)\n"," else:\n"," self.model = DistilBertModel(config=DistilBertConfig())\n","\n"," for p in self.model.parameters():\n"," p.requires_grad = trainable\n","\n"," # we are using the CLS token hidden representation as the sentence's embedding\n"," self.target_token_idx = 0\n","\n"," def forward(self, input_ids, attention_mask):\n"," output = self.model(input_ids=input_ids, attention_mask=attention_mask)\n"," last_hidden_state = output.last_hidden_state\n"," return last_hidden_state[:, self.target_token_idx, 
:]"],"metadata":{"id":"ju9a8O0JwQBx"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class ProjectionHead(nn.Module):\n"," def __init__(\n"," self,\n"," embedding_dim,\n"," projection_dim=CFG.projection_dim,\n"," dropout=CFG.dropout\n"," ):\n"," super().__init__()\n"," self.projection = nn.Linear(embedding_dim, projection_dim)\n"," self.gelu = nn.GELU()\n"," self.fc = nn.Linear(projection_dim, projection_dim)\n"," self.dropout = nn.Dropout(dropout)\n"," self.layer_norm = nn.LayerNorm(projection_dim)\n","\n"," def forward(self, x):\n"," projected = self.projection(x)\n"," x = self.gelu(projected)\n"," x = self.fc(x)\n"," x = self.dropout(x)\n"," x = x + projected\n"," x = self.layer_norm(x)\n"," return x"],"metadata":{"id":"UXREwtR9wRpI"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class CLIPModel(nn.Module):\n"," def __init__(\n"," self,\n"," temperature=CFG.temperature,\n"," image_embedding=CFG.image_embedding,\n"," text_embedding=CFG.text_embedding,\n"," ):\n"," super().__init__()\n"," self.image_encoder = ImageEncoder()\n"," self.text_encoder = TextEncoder()\n"," self.image_projection = ProjectionHead(embedding_dim=image_embedding)\n"," self.text_projection = ProjectionHead(embedding_dim=text_embedding)\n"," self.temperature = temperature\n","\n"," def forward(self, batch):\n"," # Getting Image and Text Features\n"," image_features = self.image_encoder(batch[\"image\"])\n"," text_features = self.text_encoder(\n"," input_ids=batch[\"input_ids\"], attention_mask=batch[\"attention_mask\"]\n"," )\n"," # Getting Image and Text Embeddings (with same dimension)\n"," image_embeddings = self.image_projection(image_features)\n"," text_embeddings = self.text_projection(text_features)\n","\n"," # Calculating the Loss\n"," logits = (text_embeddings @ image_embeddings.T) / self.temperature\n"," images_similarity = image_embeddings @ image_embeddings.T\n"," texts_similarity = text_embeddings @ text_embeddings.T\n"," targets = F.softmax(\n"," (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1\n"," )\n"," texts_loss = cross_entropy(logits, targets, reduction='none')\n"," images_loss = cross_entropy(logits.T, targets.T, reduction='none')\n"," loss = (images_loss + texts_loss) / 2.0 # shape: (batch_size)\n"," return loss.mean()\n","\n","\n","def cross_entropy(preds, targets, reduction='none'):\n"," log_softmax = nn.LogSoftmax(dim=-1)\n"," loss = (-targets * log_softmax(preds)).sum(1)\n"," if reduction == \"none\":\n"," return loss\n"," elif reduction == \"mean\":\n"," return loss.mean()"],"metadata":{"id":"1_fhZIC4wTQA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# A simple Example\n","\n","batch_size = 4\n","dim = 256\n","embeddings = torch.randn(batch_size, dim)\n","out = embeddings @ embeddings.T\n","print(F.softmax(out, dim=-1))"],"metadata":{"id":"ajzH97wGwVwD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def make_train_valid_dfs():\n"," dataframe = pd.read_csv(f\"{CFG.captions_path}/captions.csv\")\n"," max_id = dataframe[\"id\"].max() + 1 if not CFG.debug else 100\n"," image_ids = np.arange(0, max_id)\n"," np.random.seed(42)\n"," valid_ids = np.random.choice(\n"," image_ids, size=int(0.2 * len(image_ids)), replace=False\n"," )\n"," train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]\n"," train_dataframe = dataframe[dataframe[\"id\"].isin(train_ids)].reset_index(drop=True)\n"," valid_dataframe = dataframe[dataframe[\"id\"].isin(valid_ids)].reset_index(drop=True)\n"," return 
train_dataframe, valid_dataframe\n","\n","\n","def build_loaders(dataframe, tokenizer, mode):\n"," transforms = get_transforms(mode=mode)\n"," dataset = CLIPDataset(\n"," dataframe[\"image\"].values,\n"," dataframe[\"caption\"].values,\n"," tokenizer=tokenizer,\n"," transforms=transforms,\n"," )\n"," dataloader = torch.utils.data.DataLoader(\n"," dataset,\n"," batch_size=CFG.batch_size,\n"," num_workers=CFG.num_workers,\n"," shuffle=True if mode == \"train\" else False,\n"," )\n"," return dataloader"],"metadata":{"id":"mYNL9TCewXMG"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def train_epoch(model, train_loader, optimizer, lr_scheduler, step):\n"," loss_meter = AvgMeter()\n"," tqdm_object = tqdm(train_loader, total=len(train_loader))\n"," for batch in tqdm_object:\n"," batch = {k: v.to(CFG.device) for k, v in batch.items() if k != \"caption\"}\n"," loss = model(batch)\n"," optimizer.zero_grad()\n"," loss.backward()\n"," optimizer.step()\n"," if step == \"batch\":\n"," lr_scheduler.step()\n","\n"," count = batch[\"image\"].size(0)\n"," loss_meter.update(loss.item(), count)\n","\n"," tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))\n"," return loss_meter\n","\n","\n","def valid_epoch(model, valid_loader):\n"," loss_meter = AvgMeter()\n","\n"," tqdm_object = tqdm(valid_loader, total=len(valid_loader))\n"," for batch in tqdm_object:\n"," batch = {k: v.to(CFG.device) for k, v in batch.items() if k != \"caption\"}\n"," loss = model(batch)\n","\n"," count = batch[\"image\"].size(0)\n"," loss_meter.update(loss.item(), count)\n","\n"," tqdm_object.set_postfix(valid_loss=loss_meter.avg)\n"," return loss_meter\n","\n","\n","def main():\n"," train_df, valid_df = make_train_valid_dfs()\n"," tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)\n"," train_loader = build_loaders(train_df, tokenizer, mode=\"train\")\n"," valid_loader = build_loaders(valid_df, tokenizer, mode=\"valid\")\n","\n","\n"," model = CLIPModel().to(CFG.device)\n"," params = [\n"," {\"params\": model.image_encoder.parameters(), \"lr\": CFG.image_encoder_lr},\n"," {\"params\": model.text_encoder.parameters(), \"lr\": CFG.text_encoder_lr},\n"," {\"params\": itertools.chain(\n"," model.image_projection.parameters(), model.text_projection.parameters()\n"," ), \"lr\": CFG.head_lr, \"weight_decay\": CFG.weight_decay}\n"," ]\n"," optimizer = torch.optim.AdamW(params, weight_decay=0.)\n"," lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n"," optimizer, mode=\"min\", patience=CFG.patience, factor=CFG.factor\n"," )\n"," step = \"epoch\"\n","\n"," best_loss = float('inf')\n"," for epoch in range(CFG.epochs):\n"," print(f\"Epoch: {epoch + 1}\")\n"," model.train()\n"," train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)\n"," model.eval()\n"," with torch.no_grad():\n"," valid_loss = valid_epoch(model, valid_loader)\n","\n"," if valid_loss.avg < best_loss:\n"," best_loss = valid_loss.avg\n"," torch.save(model.state_dict(), \"best.pt\")\n"," print(\"Saved Best Model!\")\n","\n"," lr_scheduler.step(valid_loss.avg)"],"metadata":{"id":"ftucX7sqwYy8"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def get_image_embeddings(valid_df, model_path):\n"," tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)\n"," valid_loader = build_loaders(valid_df, tokenizer, mode=\"valid\")\n","\n"," model = CLIPModel().to(CFG.device)\n"," model.load_state_dict(torch.load(model_path, map_location=CFG.device))\n"," model.eval()\n","\n"," 
valid_image_embeddings = []\n"," with torch.no_grad():\n"," for batch in tqdm(valid_loader):\n"," image_features = model.image_encoder(batch[\"image\"].to(CFG.device))\n"," image_embeddings = model.image_projection(image_features)\n"," valid_image_embeddings.append(image_embeddings)\n"," return model, torch.cat(valid_image_embeddings)"],"metadata":{"id":"RFD9cjhywakf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, valid_df = make_train_valid_dfs()\n","model, image_embeddings = get_image_embeddings(valid_df, \"best.pt\")"],"metadata":{"id":"xwI4wdTfwcQ0"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def find_matches(model, image_embeddings, query, image_filenames, n=9):\n"," tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)\n"," encoded_query = tokenizer([query])\n"," batch = {\n"," key: torch.tensor(values).to(CFG.device)\n"," for key, values in encoded_query.items()\n"," }\n"," with torch.no_grad():\n"," text_features = model.text_encoder(\n"," input_ids=batch[\"input_ids\"], attention_mask=batch[\"attention_mask\"]\n"," )\n"," text_embeddings = model.text_projection(text_features)\n","\n"," image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)\n"," text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)\n"," dot_similarity = text_embeddings_n @ image_embeddings_n.T\n","\n"," values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)\n"," matches = [image_filenames[idx] for idx in indices[::5]]\n","\n"," _, axes = plt.subplots(3, 3, figsize=(10, 10))\n"," for match, ax in zip(matches, axes.flatten()):\n"," image = cv2.imread(f\"{CFG.image_path}/{match}\")\n"," image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n"," ax.imshow(image)\n"," ax.axis(\"off\")\n","\n"," plt.show()"],"metadata":{"id":"G4cfYebRwfAi"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["find_matches(model,\n"," image_embeddings,\n"," query=\"dogs on the grass\",\n"," image_filenames=valid_df['image'].values,\n"," n=9)"],"metadata":{"id":"5ikO765hwgce"},"execution_count":null,"outputs":[]}]}
\ No newline at end of file
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyNdyPaV9EsYySAFZkGn8PS0"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","source":["!pip install openai"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"s-v46x9zDqJ4","executionInfo":{"status":"ok","timestamp":1707619579730,"user_tz":-480,"elapsed":9172,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"90c11041-3189-4c8e-8ae4-8265f4291b2e"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting openai\n"," Downloading openai-1.12.0-py3-none-any.whl (226 kB)\n","\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/226.7 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m122.9/226.7 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m226.7/226.7 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from openai) (3.7.1)\n","Requirement already satisfied: distro<2,>=1.7.0 in /usr/lib/python3/dist-packages (from openai) (1.7.0)\n","Collecting httpx<1,>=0.23.0 (from openai)\n"," Downloading httpx-0.26.0-py3-none-any.whl (75 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.9/75.9 kB\u001b[0m \u001b[31m9.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from openai) (2.6.1)\n","Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from openai) (1.3.0)\n","Requirement already satisfied: tqdm>4 in /usr/local/lib/python3.10/dist-packages (from openai) (4.66.1)\n","Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from openai) (4.9.0)\n","Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (3.6)\n","Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (1.2.0)\n","Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->openai) (2024.2.2)\n","Collecting httpcore==1.* (from httpx<1,>=0.23.0->openai)\n"," Downloading httpcore-1.0.2-py3-none-any.whl (76 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m76.9/76.9 kB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting h11<0.15,>=0.13 (from httpcore==1.*->httpx<1,>=0.23.0->openai)\n"," Downloading h11-0.14.0-py3-none-any.whl (58 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (0.6.0)\n","Requirement already satisfied: pydantic-core==2.16.2 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (2.16.2)\n","Installing collected packages: h11, httpcore, httpx, 
openai\n","\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n","llmx 0.0.15a0 requires cohere, which is not installed.\n","llmx 0.0.15a0 requires tiktoken, which is not installed.\u001b[0m\u001b[31m\n","\u001b[0mSuccessfully installed h11-0.14.0 httpcore-1.0.2 httpx-0.26.0 openai-1.12.0\n"]}]},{"cell_type":"code","execution_count":4,"metadata":{"id":"nl3GO59eCxQv","executionInfo":{"status":"ok","timestamp":1707619605853,"user_tz":-480,"elapsed":6067,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"outputs":[],"source":["from IPython.display import display, Image\n","from openai import OpenAI\n","import os\n","import pandas as pd\n","import json\n","import io\n","from PIL import Image\n","import requests\n","\n","from google.colab import userdata\n","API_key=userdata.get('OpenAI-Key')\n","\n","client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", API_key))\n","\n","#Lets import some helper functions for assistants from https://cookbook.openai.com/examples/assistants_api_overview_python\n","def show_json(obj):\n"," display(json.loads(obj.model_dump_json()))\n","\n","def submit_message(assistant_id, thread, user_message,file_ids=None):\n"," params = {\n"," 'thread_id': thread.id,\n"," 'role': 'user',\n"," 'content': user_message,\n"," }\n"," if file_ids:\n"," params['file_ids']=file_ids\n","\n"," client.beta.threads.messages.create(\n"," **params\n",")\n"," return client.beta.threads.runs.create(\n"," thread_id=thread.id,\n"," assistant_id=assistant_id,\n",")\n","\n","def get_response(thread):\n"," return client.beta.threads.messages.list(thread_id=thread.id)"]},{"cell_type":"code","source":["financial_data_path = 'NotRealCorp_financial_data.json'\n","financial_data = pd.read_json(financial_data_path)\n","financial_data.head(5)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":206},"id":"Mv4kJr1sEAPq","executionInfo":{"status":"ok","timestamp":1707619718449,"user_tz":-480,"elapsed":5,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"3197abb8-ac8d-4400-859c-5502df91c76e"},"execution_count":6,"outputs":[{"output_type":"execute_result","data":{"text/plain":[" Year Quarter Distribution channel Revenue ($M) Costs ($M) \\\n","0 2021 Q1 Online Sales 1.50 1.301953 \n","1 2021 Q1 Direct Sales 1.50 1.380809 \n","2 2021 Q1 Retail Partners 1.50 1.348246 \n","3 2021 Q2 Online Sales 1.52 1.308608 \n","4 2021 Q2 Direct Sales 1.52 1.413305 \n","\n"," Customer count Time \n","0 150 2021 Q1 \n","1 151 2021 Q1 \n","2 152 2021 Q1 \n","3 152 2021 Q2 \n","4 153 2021 Q2 "],"text/html":["\n"," <div id=\"df-b3640cdb-6261-48f0-87e0-bb3e03ceb4e4\" class=\"colab-df-container\">\n"," <div>\n","<style scoped>\n"," .dataframe tbody tr th:only-of-type {\n"," vertical-align: middle;\n"," }\n","\n"," .dataframe tbody tr th {\n"," vertical-align: top;\n"," }\n","\n"," .dataframe thead th {\n"," text-align: right;\n"," }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n"," <thead>\n"," <tr style=\"text-align: right;\">\n"," <th></th>\n"," <th>Year</th>\n"," <th>Quarter</th>\n"," <th>Distribution channel</th>\n"," <th>Revenue ($M)</th>\n"," <th>Costs ($M)</th>\n"," <th>Customer count</th>\n"," <th>Time</th>\n"," </tr>\n"," </thead>\n"," <tbody>\n"," <tr>\n"," <th>0</th>\n"," <td>2021</td>\n"," <td>Q1</td>\n"," <td>Online Sales</td>\n"," <td>1.50</td>\n"," <td>1.301953</td>\n"," <td>150</td>\n"," <td>2021 
Q1</td>\n"," </tr>\n"," <tr>\n"," <th>1</th>\n"," <td>2021</td>\n"," <td>Q1</td>\n"," <td>Direct Sales</td>\n"," <td>1.50</td>\n"," <td>1.380809</td>\n"," <td>151</td>\n"," <td>2021 Q1</td>\n"," </tr>\n"," <tr>\n"," <th>2</th>\n"," <td>2021</td>\n"," <td>Q1</td>\n"," <td>Retail Partners</td>\n"," <td>1.50</td>\n"," <td>1.348246</td>\n"," <td>152</td>\n"," <td>2021 Q1</td>\n"," </tr>\n"," <tr>\n"," <th>3</th>\n"," <td>2021</td>\n"," <td>Q2</td>\n"," <td>Online Sales</td>\n"," <td>1.52</td>\n"," <td>1.308608</td>\n"," <td>152</td>\n"," <td>2021 Q2</td>\n"," </tr>\n"," <tr>\n"," <th>4</th>\n"," <td>2021</td>\n"," <td>Q2</td>\n"," <td>Direct Sales</td>\n"," <td>1.52</td>\n"," <td>1.413305</td>\n"," <td>153</td>\n"," <td>2021 Q2</td>\n"," </tr>\n"," </tbody>\n","</table>\n","</div>\n"," <div class=\"colab-df-buttons\">\n","\n"," <div class=\"colab-df-container\">\n"," <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-b3640cdb-6261-48f0-87e0-bb3e03ceb4e4')\"\n"," title=\"Convert this dataframe to an interactive table.\"\n"," style=\"display:none;\">\n","\n"," <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n"," <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n"," </svg>\n"," </button>\n","\n"," <style>\n"," .colab-df-container {\n"," display:flex;\n"," gap: 12px;\n"," }\n","\n"," .colab-df-convert {\n"," background-color: #E8F0FE;\n"," border: none;\n"," border-radius: 50%;\n"," cursor: pointer;\n"," display: none;\n"," fill: #1967D2;\n"," height: 32px;\n"," padding: 0 0 0 0;\n"," width: 32px;\n"," }\n","\n"," .colab-df-convert:hover {\n"," background-color: #E2EBFA;\n"," box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n"," fill: #174EA6;\n"," }\n","\n"," .colab-df-buttons div {\n"," margin-bottom: 4px;\n"," }\n","\n"," [theme=dark] .colab-df-convert {\n"," background-color: #3B4455;\n"," fill: #D2E3FC;\n"," }\n","\n"," [theme=dark] .colab-df-convert:hover {\n"," background-color: #434B5C;\n"," box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n"," filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n"," fill: #FFFFFF;\n"," }\n"," </style>\n","\n"," <script>\n"," const buttonEl =\n"," document.querySelector('#df-b3640cdb-6261-48f0-87e0-bb3e03ceb4e4 button.colab-df-convert');\n"," buttonEl.style.display =\n"," google.colab.kernel.accessAllowed ? 'block' : 'none';\n","\n"," async function convertToInteractive(key) {\n"," const element = document.querySelector('#df-b3640cdb-6261-48f0-87e0-bb3e03ceb4e4');\n"," const dataTable =\n"," await google.colab.kernel.invokeFunction('convertToInteractive',\n"," [key], {});\n"," if (!dataTable) return;\n","\n"," const docLinkHtml = 'Like what you see? 
Visit the ' +\n"," '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n"," + ' to learn more about interactive tables.';\n"," element.innerHTML = '';\n"," dataTable['output_type'] = 'display_data';\n"," await google.colab.output.renderOutput(dataTable, element);\n"," const docLink = document.createElement('div');\n"," docLink.innerHTML = docLinkHtml;\n"," element.appendChild(docLink);\n"," }\n"," </script>\n"," </div>\n","\n","\n","<div id=\"df-3e6a4d79-2d22-42d5-b89b-e8cbfa750de6\">\n"," <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-3e6a4d79-2d22-42d5-b89b-e8cbfa750de6')\"\n"," title=\"Suggest charts\"\n"," style=\"display:none;\">\n","\n","<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n"," width=\"24px\">\n"," <g>\n"," <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n"," </g>\n","</svg>\n"," </button>\n","\n","<style>\n"," .colab-df-quickchart {\n"," --bg-color: #E8F0FE;\n"," --fill-color: #1967D2;\n"," --hover-bg-color: #E2EBFA;\n"," --hover-fill-color: #174EA6;\n"," --disabled-fill-color: #AAA;\n"," --disabled-bg-color: #DDD;\n"," }\n","\n"," [theme=dark] .colab-df-quickchart {\n"," --bg-color: #3B4455;\n"," --fill-color: #D2E3FC;\n"," --hover-bg-color: #434B5C;\n"," --hover-fill-color: #FFFFFF;\n"," --disabled-bg-color: #3B4455;\n"," --disabled-fill-color: #666;\n"," }\n","\n"," .colab-df-quickchart {\n"," background-color: var(--bg-color);\n"," border: none;\n"," border-radius: 50%;\n"," cursor: pointer;\n"," display: none;\n"," fill: var(--fill-color);\n"," height: 32px;\n"," padding: 0;\n"," width: 32px;\n"," }\n","\n"," .colab-df-quickchart:hover {\n"," background-color: var(--hover-bg-color);\n"," box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n"," fill: var(--button-hover-fill-color);\n"," }\n","\n"," .colab-df-quickchart-complete:disabled,\n"," .colab-df-quickchart-complete:disabled:hover {\n"," background-color: var(--disabled-bg-color);\n"," fill: var(--disabled-fill-color);\n"," box-shadow: none;\n"," }\n","\n"," .colab-df-spinner {\n"," border: 2px solid var(--fill-color);\n"," border-color: transparent;\n"," border-bottom-color: var(--fill-color);\n"," animation:\n"," spin 1s steps(1) infinite;\n"," }\n","\n"," @keyframes spin {\n"," 0% {\n"," border-color: transparent;\n"," border-bottom-color: var(--fill-color);\n"," border-left-color: var(--fill-color);\n"," }\n"," 20% {\n"," border-color: transparent;\n"," border-left-color: var(--fill-color);\n"," border-top-color: var(--fill-color);\n"," }\n"," 30% {\n"," border-color: transparent;\n"," border-left-color: var(--fill-color);\n"," border-top-color: var(--fill-color);\n"," border-right-color: var(--fill-color);\n"," }\n"," 40% {\n"," border-color: transparent;\n"," border-right-color: var(--fill-color);\n"," border-top-color: var(--fill-color);\n"," }\n"," 60% {\n"," border-color: transparent;\n"," border-right-color: var(--fill-color);\n"," }\n"," 80% {\n"," border-color: transparent;\n"," border-right-color: var(--fill-color);\n"," border-bottom-color: var(--fill-color);\n"," }\n"," 90% {\n"," border-color: transparent;\n"," border-bottom-color: var(--fill-color);\n"," }\n"," }\n","</style>\n","\n"," <script>\n"," async function quickchart(key) {\n"," const quickchartButtonEl =\n"," document.querySelector('#' + key + ' button');\n"," quickchartButtonEl.disabled = true; // To prevent 
multiple clicks.\n"," quickchartButtonEl.classList.add('colab-df-spinner');\n"," try {\n"," const charts = await google.colab.kernel.invokeFunction(\n"," 'suggestCharts', [key], {});\n"," } catch (error) {\n"," console.error('Error during call to suggestCharts:', error);\n"," }\n"," quickchartButtonEl.classList.remove('colab-df-spinner');\n"," quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n"," }\n"," (() => {\n"," let quickchartButtonEl =\n"," document.querySelector('#df-3e6a4d79-2d22-42d5-b89b-e8cbfa750de6 button');\n"," quickchartButtonEl.style.display =\n"," google.colab.kernel.accessAllowed ? 'block' : 'none';\n"," })();\n"," </script>\n","</div>\n","\n"," </div>\n"," </div>\n"]},"metadata":{},"execution_count":6}]},{"cell_type":"code","source":["file = client.files.create(\n"," file=open('NotRealCorp_financial_data.json',\"rb\"),\n"," purpose='assistants',\n",")"],"metadata":{"id":"naRttfJsEcbR","executionInfo":{"status":"ok","timestamp":1707619780967,"user_tz":-480,"elapsed":1034,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":8,"outputs":[]},{"cell_type":"code","source":["assistant = client.beta.assistants.create(\n"," instructions=\"You are a data scientist assistant. When given data and a query, write the proper code and create the proper visualization\",\n"," model=\"gpt-4-1106-preview\",\n"," tools=[{\"type\": \"code_interpreter\"}],\n"," file_ids=[file.id]\n",")"],"metadata":{"id":"bv21EWvKEjWP","executionInfo":{"status":"ok","timestamp":1707619796761,"user_tz":-480,"elapsed":511,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":9,"outputs":[]},{"cell_type":"code","source":["thread = client.beta.threads.create(\n"," messages=[\n"," {\n"," \"role\": \"user\",\n"," \"content\": \"Calculate profit (revenue minus cost) by quarter and year, and visualize as a line plot across the distribution channels, where the colors of the lines are green, light red, and light blue\",\n"," \"file_ids\": [file.id]\n"," }\n"," ]\n",")"],"metadata":{"id":"e9od4SDmEk_b","executionInfo":{"status":"ok","timestamp":1707619809125,"user_tz":-480,"elapsed":527,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":10,"outputs":[]},{"cell_type":"code","source":["run = client.beta.threads.runs.create(\n"," thread_id=thread.id,\n"," assistant_id=assistant.id,\n",")"],"metadata":{"id":"yZ2zuB63Entp","executionInfo":{"status":"ok","timestamp":1707619818426,"user_tz":-480,"elapsed":504,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":11,"outputs":[]},{"cell_type":"code","source":["messages = client.beta.threads.messages.list(thread_id=thread.id)"],"metadata":{"id":"J8laxybUErur","executionInfo":{"status":"ok","timestamp":1707619837176,"user_tz":-480,"elapsed":512,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":12,"outputs":[]},{"cell_type":"code","source":["import time\n","\n","while True:\n"," messages = client.beta.threads.messages.list(thread_id=thread.id)\n"," try:\n"," #See if image has been created\n"," messages.data[0].content[0].image_file\n"," #Sleep to make sure run has completed\n"," time.sleep(5)\n"," print('Plot created!')\n"," break\n"," except:\n"," time.sleep(10)\n"," print('Assistant still 
working...')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"dsUwZKk6Eu7o","executionInfo":{"status":"ok","timestamp":1707619978003,"user_tz":-480,"elapsed":130282,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"75c2d621-24a1-44ee-a011-844a6f39ca31"},"execution_count":13,"outputs":[{"output_type":"stream","name":"stdout","text":["Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Plot created!\n"]}]},{"cell_type":"code","source":["messages = client.beta.threads.messages.list(thread_id=thread.id)\n","[message.content[0] for message in messages.data]"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"PJeTKUdGEz_1","executionInfo":{"status":"ok","timestamp":1707620018940,"user_tz":-480,"elapsed":522,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"1c790ac4-4c14-4025-bce5-e5a16ae96744"},"execution_count":15,"outputs":[{"output_type":"execute_result","data":{"text/plain":["[MessageContentImageFile(image_file=ImageFile(file_id='file-0TaYohpWFuqF4SOCMhuy2ujI'), type='image_file'),\n"," MessageContentText(text=Text(annotations=[], value=\"It seems I made an error when attempting to drop the 'Year' and 'Quarter' columns from the pivot table's columns MultiIndex. I need to address this by correctly referencing these columns in the context of a MultiIndex. Let's correct this and try plotting again.\"), type='text'),\n"," MessageContentText(text=Text(annotations=[], value=\"It appears there is an issue with plotting the multi-index ('Year', 'Quarter') directly. To resolve this, we need to convert the multi-index into a single string that represents both Year and Quarter. I will adjust the DataFrame index and attempt the visualization again.\"), type='text'),\n"," MessageContentText(text=Text(annotations=[], value='With the profit calculated and data grouped by \"Year\", \"Quarter\", and \"Distribution channel\", we can now move on to visualizing this information as a line plot.\\n\\nWe will use three different colors for the lines representing different distribution channels as specified:\\n\\n- Online Sales: Light Blue\\n- Direct Sales: Light Red\\n- Retail Partners: Green\\n\\nLet\\'s proceed with the visualization.'), type='text'),\n"," MessageContentText(text=Text(annotations=[], value='We have successfully structured the data into a pandas DataFrame. Now we can proceed to calculate the profit by subtracting the \"Costs ($M)\" column from the \"Revenue ($M)\" column, and then group and visualize the data by year, quarter, and distribution channel. Let\\'s start with the calculation and grouping.'), type='text'),\n"," MessageContentText(text=Text(annotations=[], value='The content of the \"Year\" key appears to be a dictionary where each entry corresponds to a row index and the respective year. 
This suggests that the keys of the main JSON object represent columns that are in dictionaries themselves with row index as keys.\\n\\nI will now construct a pandas DataFrame by iterating over these keys and combining the corresponding dictionaries into a structured table.'), type='text'),\n"," MessageContentText(text=Text(annotations=[], value='The JSON object contains keys corresponding to what seems like different data columns like \"Year\", \"Quarter\", \"Distribution channel\", \"Revenue\", \"Costs\", \"Customer count\", and a combined \"Time\" column.\\n\\nIt might be a bit tricky to deal with this structure directly, so let\\'s see if we can correctly parse and normalize this JSON by columns into a pandas DataFrame. I will explore one of the columns to see its data format.'), type='text'),\n"," MessageContentText(text=Text(annotations=[], value=\"It looks like the data has been loaded incorrectly, resulting in an odd DataFrame structure that seems to include serialized JSON. I will try to unpack this serialized JSON data correctly into a usable pandas DataFrame. Let's proceed by examining the keys in the JSON object so we can determine the correct approach to normalize the data into a DataFrame.\"), type='text'),\n"," MessageContentText(text=Text(annotations=[], value=\"It appears that the attempt to read the file as a typical CSV or Excel file has failed, resulting in an empty DataFrame with an unusual column structure that suggests the data might be stored in a JSON format within the file.\\n\\nI'll now attempt to load the data as JSON to see if that correctly interprets the file structure.\"), type='text'),\n"," MessageContentText(text=Text(annotations=[], value=\"To calculate profit by quarter and year and visualize it across distribution channels as requested, I first need to inspect the contents of the uploaded file to understand its structure (column names, data types, etc.).\\n\\nLet's begin by loading the file and taking a look at the first few rows to determine how to proceed.\"), type='text'),\n"," MessageContentText(text=Text(annotations=[], value='Calculate profit (revenue minus cost) by quarter and year, and visualize as a line plot across the distribution channels, where the colors of the lines are green, light red, and light blue'), type='text')]"]},"metadata":{},"execution_count":15}]},{"cell_type":"code","source":["# Quick helper function to convert our output file to a png\n","def convert_file_to_png(file_id, write_path):\n"," data = client.files.content(file_id)\n"," data_bytes = data.read()\n"," with open(write_path, \"wb\") as file:\n"," file.write(data_bytes)"],"metadata":{"id":"CWyogxdhE5Xl","executionInfo":{"status":"ok","timestamp":1707620019701,"user_tz":-480,"elapsed":4,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":16,"outputs":[]},{"cell_type":"code","source":["plot_file_id = messages.data[0].content[0].image_file.file_id\n","image_path = \"NotRealCorp_chart.png\"\n","convert_file_to_png(plot_file_id,image_path)\n","\n","#Upload\n","plot_file = client.files.create(\n"," file=open(image_path, \"rb\"),\n"," purpose='assistants'\n",")"],"metadata":{"id":"y7x0uRvNE6y-","executionInfo":{"status":"ok","timestamp":1707620195449,"user_tz":-480,"elapsed":2152,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":18,"outputs":[]},{"cell_type":"code","source":["submit_message(assistant.id,thread,\"Give me two medium length sentences (~20-30 words per sentence) of the \\\n"," most 
important insights from the plot you just created.\\\n"," These will be used for a slide deck, and they should be about the\\\n"," 'so what' behind the data.\"\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"V_ITOsiCE98C","executionInfo":{"status":"ok","timestamp":1707620206406,"user_tz":-480,"elapsed":1026,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"8c169d2e-9c74-477d-e90d-7aaaf7cf0e78"},"execution_count":19,"outputs":[{"output_type":"execute_result","data":{"text/plain":["Run(id='run_GWFPQYRVdxr0LVyFQWOgy0mB', assistant_id='asst_s0WPNCpicIs8tCxIn2ILGwu1', cancelled_at=None, completed_at=None, created_at=1707620205, expires_at=1707620805, failed_at=None, file_ids=['file-76sx64aGS7mobabDw6W3WNPe'], instructions='You are a data scientist assistant. When given data and a query, write the proper code and create the proper visualization', last_error=None, metadata={}, model='gpt-4-1106-preview', object='thread.run', required_action=None, started_at=None, status='queued', thread_id='thread_p3iCATH9IfCJKZyfcBxoPjjV', tools=[ToolAssistantToolsCode(type='code_interpreter')], usage=None)"]},"metadata":{},"execution_count":19}]},{"cell_type":"code","source":["# Hard coded wait for a response, as the assistant may iterate on the bullets.\n","time.sleep(10)\n","response = get_response(thread)\n","bullet_points = response.data[0].content[0].text.value\n","print(bullet_points)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"A1UdmH-nE_wK","executionInfo":{"status":"ok","timestamp":1707620227601,"user_tz":-480,"elapsed":11415,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"095e74a5-ff2c-4ea3-87f8-8d9f647eed91"},"execution_count":20,"outputs":[{"output_type":"stream","name":"stdout","text":["The line plot illustrates that Online Sales consistently generates the highest profit across all quarters, indicating a strong market presence and effective online strategies. Despite seasonal fluctuations, Retail Partners show a stable profit trend, while Direct Sales appear to be the least profitable channel with more pronounced variances over time.\n"]}]},{"cell_type":"code","source":["submit_message(assistant.id,thread,\"Given the plot and bullet points you created,\\\n"," come up with a very brief title for a slide. It should reflect just the main insights you came up with.\"\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"MrNeJfmkFBmI","executionInfo":{"status":"ok","timestamp":1707620228114,"user_tz":-480,"elapsed":514,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"ca5ca6c7-d057-4303-c5a2-775124525d76"},"execution_count":21,"outputs":[{"output_type":"execute_result","data":{"text/plain":["Run(id='run_axOn3g0XObuWLz01NBb8RJWa', assistant_id='asst_s0WPNCpicIs8tCxIn2ILGwu1', cancelled_at=None, completed_at=None, created_at=1707620227, expires_at=1707620827, failed_at=None, file_ids=['file-76sx64aGS7mobabDw6W3WNPe'], instructions='You are a data scientist assistant. 
When given data and a query, write the proper code and create the proper visualization', last_error=None, metadata={}, model='gpt-4-1106-preview', object='thread.run', required_action=None, started_at=None, status='queued', thread_id='thread_p3iCATH9IfCJKZyfcBxoPjjV', tools=[ToolAssistantToolsCode(type='code_interpreter')], usage=None)"]},"metadata":{},"execution_count":21}]},{"cell_type":"code","source":["#Wait as assistant may take a few steps\n","time.sleep(10)\n","response = get_response(thread)\n","title = response.data[0].content[0].text.value\n","print(title)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"UPlwpe3UFDk9","executionInfo":{"status":"ok","timestamp":1707620238328,"user_tz":-480,"elapsed":10216,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"65565240-b594-4dbd-e3c4-83b5a9339446"},"execution_count":22,"outputs":[{"output_type":"stream","name":"stdout","text":["\"Online Sales Lead Profits Across Channels with Steady Retail Partner Performance\"\n"]}]},{"cell_type":"code","source":["company_summary = \"NotReal Corp is a prominent hardware company that manufactures and sells processors, graphics cards and other essential computer hardware.\"\n","response = client.images.generate(\n"," model='dall-e-3',\n"," prompt=f\"given this company summary {company_summary}, create an inspirational \\\n"," photo showing the growth and path forward. This will be used at a quarterly\\\n"," financial planning meeting\",\n"," size=\"1024x1024\",\n"," quality=\"hd\",\n"," n=1\n",")\n","image_url = response.data[0].url"],"metadata":{"id":"0nWlwWIqFEsT","executionInfo":{"status":"ok","timestamp":1707620253115,"user_tz":-480,"elapsed":14790,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":23,"outputs":[]},{"cell_type":"code","source":["dalle_img_path = 'dalle_image.png'\n","img = requests.get(image_url)\n","\n","#Save locally\n","with open(dalle_img_path,'wb') as file:\n"," file.write(img.content)\n","\n","#Upload\n","dalle_file = client.files.create(\n"," file=open(dalle_img_path, \"rb\"),\n"," purpose='assistants'\n",")"],"metadata":{"id":"2562HE27FJeF","executionInfo":{"status":"ok","timestamp":1707620275055,"user_tz":-480,"elapsed":6318,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":25,"outputs":[]},{"cell_type":"code","source":["title_template = \"\"\"\n","from pptx import Presentation\n","from pptx.util import Inches, Pt\n","from pptx.enum.text import PP_PARAGRAPH_ALIGNMENT\n","from pptx.dml.color import RGBColor\n","\n","# Create a new presentation object\n","prs = Presentation()\n","\n","# Add a blank slide layout\n","blank_slide_layout = prs.slide_layouts[6]\n","slide = prs.slides.add_slide(blank_slide_layout)\n","\n","# Set the background color of the slide to black\n","background = slide.background\n","fill = background.fill\n","fill.solid()\n","fill.fore_color.rgb = RGBColor(0, 0, 0)\n","\n","# Add image to the left side of the slide with a margin at the top and bottom\n","left = Inches(0)\n","top = Inches(0)\n","height = prs.slide_height\n","width = prs.slide_width * 3/5\n","pic = slide.shapes.add_picture(image_path, left, top, width=width, height=height)\n","\n","# Add title text box positioned higher\n","left = prs.slide_width * 3/5\n","top = Inches(2)\n","width = prs.slide_width * 2/5\n","height = Inches(1)\n","title_box = slide.shapes.add_textbox(left, top, width, height)\n","title_frame = title_box.text_frame\n","title_p = 
title_frame.add_paragraph()\n","title_p.text = title_text\n","title_p.font.bold = True\n","title_p.font.size = Pt(38)\n","title_p.font.color.rgb = RGBColor(255, 255, 255)\n","title_p.alignment = PP_PARAGRAPH_ALIGNMENT.CENTER\n","\n","# Add subtitle text box\n","left = prs.slide_width * 3/5\n","top = Inches(3)\n","width = prs.slide_width * 2/5\n","height = Inches(1)\n","subtitle_box = slide.shapes.add_textbox(left, top, width, height)\n","subtitle_frame = subtitle_box.text_frame\n","subtitle_p = subtitle_frame.add_paragraph()\n","subtitle_p.text = subtitle_text\n","subtitle_p.font.size = Pt(22)\n","subtitle_p.font.color.rgb = RGBColor(255, 255, 255)\n","subtitle_p.alignment = PP_PARAGRAPH_ALIGNMENT.CENTER\n","\"\"\"\n","\n","data_vis_template = \"\"\"\n","from pptx import Presentation\n","from pptx.util import Inches, Pt\n","from pptx.enum.text import PP_PARAGRAPH_ALIGNMENT\n","from pptx.dml.color import RGBColor\n","\n","# Create a new presentation object\n","prs = Presentation()\n","\n","# Add a blank slide layout\n","blank_slide_layout = prs.slide_layouts[6]\n","slide = prs.slides.add_slide(blank_slide_layout)\n","\n","# Set the background color of the slide to black\n","background = slide.background\n","fill = background.fill\n","fill.solid()\n","fill.fore_color.rgb = RGBColor(0, 0, 0)\n","\n","# Define placeholders\n","image_path = data_vis_img\n","title_text = \"Maximizing Profits: The Dominance of Online Sales & Direct Sales Optimization\"\n","bullet_points = \"• Online Sales consistently lead in profitability across quarters, indicating a strong digital market presence.\\n• Direct Sales show fluctuations, suggesting variable performance and the need for targeted improvements in that channel.\"\n","\n","# Add image placeholder on the left side of the slide\n","left = Inches(0.2)\n","top = Inches(1.8)\n","height = prs.slide_height - Inches(3)\n","width = prs.slide_width * 3/5\n","pic = slide.shapes.add_picture(image_path, left, top, width=width, height=height)\n","\n","# Add title text spanning the whole width\n","left = Inches(0)\n","top = Inches(0)\n","width = prs.slide_width\n","height = Inches(1)\n","title_box = slide.shapes.add_textbox(left, top, width, height)\n","title_frame = title_box.text_frame\n","title_frame.margin_top = Inches(0.1)\n","title_p = title_frame.add_paragraph()\n","title_p.text = title_text\n","title_p.font.bold = True\n","title_p.font.size = Pt(28)\n","title_p.font.color.rgb = RGBColor(255, 255, 255)\n","title_p.alignment = PP_PARAGRAPH_ALIGNMENT.CENTER\n","\n","# Add hardcoded \"Key Insights\" text and bullet points\n","left = prs.slide_width * 2/3\n","top = Inches(1.5)\n","width = prs.slide_width * 1/3\n","height = Inches(4.5)\n","insights_box = slide.shapes.add_textbox(left, top, width, height)\n","insights_frame = insights_box.text_frame\n","insights_p = insights_frame.add_paragraph()\n","insights_p.text = \"Key Insights:\"\n","insights_p.font.bold = True\n","insights_p.font.size = Pt(24)\n","insights_p.font.color.rgb = RGBColor(0, 128, 100)\n","insights_p.alignment = PP_PARAGRAPH_ALIGNMENT.LEFT\n","insights_frame.add_paragraph()\n","\n","\n","bullet_p = insights_frame.add_paragraph()\n","bullet_p.text = bullet_points\n","bullet_p.font.size = Pt(12)\n","bullet_p.font.color.rgb = RGBColor(255, 255, 255)\n","bullet_p.line_spacing = 1.5\n","\"\"\""],"metadata":{"id":"0i-dDVFkFOaO","executionInfo":{"status":"ok","timestamp":1707620278192,"user_tz":-480,"elapsed":502,"user":{"displayName":"Liu 
LingHui","userId":"15918876591374452686"}}},"execution_count":26,"outputs":[]},{"cell_type":"code","source":["title_text = \"NotRealCorp\"\n","subtitle_text = \"Quarterly financial planning meeting, Q3 2023\""],"metadata":{"id":"E5prceoBGQcz","executionInfo":{"status":"ok","timestamp":1707620339642,"user_tz":-480,"elapsed":2,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":28,"outputs":[]},{"cell_type":"code","source":["submit_message(assistant.id,thread,f\"Use the included code template to create a PPTX slide that follows the template format, but uses the image, company name/title, and document name/subtitle included:\\\n","{title_template}. IMPORTANT: Use the image file included in this message as the image_path image in this first slide, and use the Company Name {title_text} as the title_text variable, and \\\n"," use the subtitle_text {subtitle_text} a the subtitle_text variable. \\\n"," NEST, create a SECOND slide using the following code template: {data_vis_template} to create a PPTX slide that follows the template format, but uses the company name/title, and document name/subtitle included:\\\n","{data_vis_template}. IMPORTANT: Use the line plot image, that is the second attached image in this message, that you created earlier in the thread as the data_vis_img image, and use the data visualization title that you created earlier for the variable title_text, and\\\n"," the bullet points of insights you created earlier for the bullet_points variable. Output these TWO SLIDES as a .pptx file. Make sure the output is two slides, with each slide matching the respective template given in this message.\",\n"," file_ids=[dalle_file.id, plot_file.id]\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"NAHKthE5Ga02","executionInfo":{"status":"ok","timestamp":1707620340656,"user_tz":-480,"elapsed":1016,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"33508da3-71e5-4df7-9ea8-26e6eae66f21"},"execution_count":29,"outputs":[{"output_type":"execute_result","data":{"text/plain":["Run(id='run_eQvuBgDgLPEBDCE3C38UAnUM', assistant_id='asst_s0WPNCpicIs8tCxIn2ILGwu1', cancelled_at=None, completed_at=None, created_at=1707620339, expires_at=1707620939, failed_at=None, file_ids=['file-76sx64aGS7mobabDw6W3WNPe'], instructions='You are a data scientist assistant. 
When given data and a query, write the proper code and create the proper visualization', last_error=None, metadata={}, model='gpt-4-1106-preview', object='thread.run', required_action=None, started_at=None, status='queued', thread_id='thread_p3iCATH9IfCJKZyfcBxoPjjV', tools=[ToolAssistantToolsCode(type='code_interpreter')], usage=None)"]},"metadata":{},"execution_count":29}]},{"cell_type":"code","source":["#May take 1-3 mins\n","while True:\n"," try:\n"," response = get_response(thread)\n"," pptx_id = response.data[0].content[0].text.annotations[0].file_path.file_id\n"," print(\"Successfully retrieved pptx_id:\", pptx_id)\n"," break\n"," except Exception as e:\n"," print(\"Assistant still working on PPTX...\")\n"," time.sleep(10)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"XIPEkk4TGbMS","executionInfo":{"status":"ok","timestamp":1707620436104,"user_tz":-480,"elapsed":95450,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"4f281d7b-019d-410d-9218-60ca56479b64"},"execution_count":30,"outputs":[{"output_type":"stream","name":"stdout","text":["Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Successfully retrieved pptx_id: file-NwArFbjmmFw0tnboToe6f18N\n"]}]},{"cell_type":"code","source":["pptx_id = response.data[0].content[0].text.annotations[0].file_path.file_id\n","ppt_file= client.files.content(pptx_id)\n","file_obj = io.BytesIO(ppt_file.read())\n","with open(\"created_slides.pptx\", \"wb\") as f:\n"," f.write(file_obj.getbuffer())"],"metadata":{"id":"CtqZHLolGbyT","executionInfo":{"status":"ok","timestamp":1707620446153,"user_tz":-480,"elapsed":1151,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":32,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"MxJ_Co2ZGoxf"},"execution_count":null,"outputs":[]}]}
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyNdyPaV9EsYySAFZkGn8PS0"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","source":["!pip install openai"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"s-v46x9zDqJ4","executionInfo":{"status":"ok","timestamp":1707619579730,"user_tz":-480,"elapsed":9172,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"90c11041-3189-4c8e-8ae4-8265f4291b2e"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting openai\n"," Downloading openai-1.12.0-py3-none-any.whl (226 kB)\n","\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/226.7 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m122.9/226.7 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m226.7/226.7 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from openai) (3.7.1)\n","Requirement already satisfied: distro<2,>=1.7.0 in /usr/lib/python3/dist-packages (from openai) (1.7.0)\n","Collecting httpx<1,>=0.23.0 (from openai)\n"," Downloading httpx-0.26.0-py3-none-any.whl (75 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.9/75.9 kB\u001b[0m \u001b[31m9.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from openai) (2.6.1)\n","Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from openai) (1.3.0)\n","Requirement already satisfied: tqdm>4 in /usr/local/lib/python3.10/dist-packages (from openai) (4.66.1)\n","Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from openai) (4.9.0)\n","Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (3.6)\n","Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (1.2.0)\n","Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->openai) (2024.2.2)\n","Collecting httpcore==1.* (from httpx<1,>=0.23.0->openai)\n"," Downloading httpcore-1.0.2-py3-none-any.whl (76 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m76.9/76.9 kB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting h11<0.15,>=0.13 (from httpcore==1.*->httpx<1,>=0.23.0->openai)\n"," Downloading h11-0.14.0-py3-none-any.whl (58 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (0.6.0)\n","Requirement already satisfied: pydantic-core==2.16.2 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (2.16.2)\n","Installing collected packages: h11, httpcore, httpx, 
openai\n","\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n","llmx 0.0.15a0 requires cohere, which is not installed.\n","llmx 0.0.15a0 requires tiktoken, which is not installed.\u001b[0m\u001b[31m\n","\u001b[0mSuccessfully installed h11-0.14.0 httpcore-1.0.2 httpx-0.26.0 openai-1.12.0\n"]}]},{"cell_type":"code","execution_count":4,"metadata":{"id":"nl3GO59eCxQv","executionInfo":{"status":"ok","timestamp":1707619605853,"user_tz":-480,"elapsed":6067,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"outputs":[],"source":["from IPython.display import display, Image\n","from openai import OpenAI\n","import os\n","import pandas as pd\n","import json\n","import io\n","from PIL import Image\n","import requests\n","\n","from google.colab import userdata\n","API_key=userdata.get('OpenAI-Key')\n","\n","client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", API_key))\n","\n","#Lets import some helper functions for assistants from https://cookbook.openai.com/examples/assistants_api_overview_python\n","def show_json(obj):\n"," display(json.loads(obj.model_dump_json()))\n","\n","def submit_message(assistant_id, thread, user_message,file_ids=None):\n"," params = {\n"," 'thread_id': thread.id,\n"," 'role': 'user',\n"," 'content': user_message,\n"," }\n"," if file_ids:\n"," params['file_ids']=file_ids\n","\n"," client.beta.threads.messages.create(\n"," **params\n",")\n"," return client.beta.threads.runs.create(\n"," thread_id=thread.id,\n"," assistant_id=assistant_id,\n",")\n","\n","def get_response(thread):\n"," return client.beta.threads.messages.list(thread_id=thread.id)"]},{"cell_type":"code","source":["financial_data_path = 'NotRealCorp_financial_data.json'\n","financial_data = pd.read_json(financial_data_path)\n","financial_data.head(5)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":206},"id":"Mv4kJr1sEAPq","executionInfo":{"status":"ok","timestamp":1707619718449,"user_tz":-480,"elapsed":5,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"3197abb8-ac8d-4400-859c-5502df91c76e"},"execution_count":6,"outputs":[{"output_type":"execute_result","data":{"text/plain":[" Year Quarter Distribution channel Revenue ($M) Costs ($M) \\\n","0 2021 Q1 Online Sales 1.50 1.301953 \n","1 2021 Q1 Direct Sales 1.50 1.380809 \n","2 2021 Q1 Retail Partners 1.50 1.348246 \n","3 2021 Q2 Online Sales 1.52 1.308608 \n","4 2021 Q2 Direct Sales 1.52 1.413305 \n","\n"," Customer count Time \n","0 150 2021 Q1 \n","1 151 2021 Q1 \n","2 152 2021 Q1 \n","3 152 2021 Q2 \n","4 153 2021 Q2 "],"text/html":["\n"," <div id=\"df-b3640cdb-6261-48f0-87e0-bb3e03ceb4e4\" class=\"colab-df-container\">\n"," <div>\n","<style scoped>\n"," .dataframe tbody tr th:only-of-type {\n"," vertical-align: middle;\n"," }\n","\n"," .dataframe tbody tr th {\n"," vertical-align: top;\n"," }\n","\n"," .dataframe thead th {\n"," text-align: right;\n"," }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n"," <thead>\n"," <tr style=\"text-align: right;\">\n"," <th></th>\n"," <th>Year</th>\n"," <th>Quarter</th>\n"," <th>Distribution channel</th>\n"," <th>Revenue ($M)</th>\n"," <th>Costs ($M)</th>\n"," <th>Customer count</th>\n"," <th>Time</th>\n"," </tr>\n"," </thead>\n"," <tbody>\n"," <tr>\n"," <th>0</th>\n"," <td>2021</td>\n"," <td>Q1</td>\n"," <td>Online Sales</td>\n"," <td>1.50</td>\n"," <td>1.301953</td>\n"," <td>150</td>\n"," <td>2021 
Q1</td>\n"," </tr>\n"," <tr>\n"," <th>1</th>\n"," <td>2021</td>\n"," <td>Q1</td>\n"," <td>Direct Sales</td>\n"," <td>1.50</td>\n"," <td>1.380809</td>\n"," <td>151</td>\n"," <td>2021 Q1</td>\n"," </tr>\n"," <tr>\n"," <th>2</th>\n"," <td>2021</td>\n"," <td>Q1</td>\n"," <td>Retail Partners</td>\n"," <td>1.50</td>\n"," <td>1.348246</td>\n"," <td>152</td>\n"," <td>2021 Q1</td>\n"," </tr>\n"," <tr>\n"," <th>3</th>\n"," <td>2021</td>\n"," <td>Q2</td>\n"," <td>Online Sales</td>\n"," <td>1.52</td>\n"," <td>1.308608</td>\n"," <td>152</td>\n"," <td>2021 Q2</td>\n"," </tr>\n"," <tr>\n"," <th>4</th>\n"," <td>2021</td>\n"," <td>Q2</td>\n"," <td>Direct Sales</td>\n"," <td>1.52</td>\n"," <td>1.413305</td>\n"," <td>153</td>\n"," <td>2021 Q2</td>\n"," </tr>\n"," </tbody>\n","</table>\n","</div>\n"," <div class=\"colab-df-buttons\">\n","\n"," <div class=\"colab-df-container\">\n"," <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-b3640cdb-6261-48f0-87e0-bb3e03ceb4e4')\"\n"," title=\"Convert this dataframe to an interactive table.\"\n"," style=\"display:none;\">\n","\n"," <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n"," <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n"," </svg>\n"," </button>\n","\n"," <style>\n"," .colab-df-container {\n"," display:flex;\n"," gap: 12px;\n"," }\n","\n"," .colab-df-convert {\n"," background-color: #E8F0FE;\n"," border: none;\n"," border-radius: 50%;\n"," cursor: pointer;\n"," display: none;\n"," fill: #1967D2;\n"," height: 32px;\n"," padding: 0 0 0 0;\n"," width: 32px;\n"," }\n","\n"," .colab-df-convert:hover {\n"," background-color: #E2EBFA;\n"," box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n"," fill: #174EA6;\n"," }\n","\n"," .colab-df-buttons div {\n"," margin-bottom: 4px;\n"," }\n","\n"," [theme=dark] .colab-df-convert {\n"," background-color: #3B4455;\n"," fill: #D2E3FC;\n"," }\n","\n"," [theme=dark] .colab-df-convert:hover {\n"," background-color: #434B5C;\n"," box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n"," filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n"," fill: #FFFFFF;\n"," }\n"," </style>\n","\n"," <script>\n"," const buttonEl =\n"," document.querySelector('#df-b3640cdb-6261-48f0-87e0-bb3e03ceb4e4 button.colab-df-convert');\n"," buttonEl.style.display =\n"," google.colab.kernel.accessAllowed ? 'block' : 'none';\n","\n"," async function convertToInteractive(key) {\n"," const element = document.querySelector('#df-b3640cdb-6261-48f0-87e0-bb3e03ceb4e4');\n"," const dataTable =\n"," await google.colab.kernel.invokeFunction('convertToInteractive',\n"," [key], {});\n"," if (!dataTable) return;\n","\n"," const docLinkHtml = 'Like what you see? 
Visit the ' +\n"," '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n"," + ' to learn more about interactive tables.';\n"," element.innerHTML = '';\n"," dataTable['output_type'] = 'display_data';\n"," await google.colab.output.renderOutput(dataTable, element);\n"," const docLink = document.createElement('div');\n"," docLink.innerHTML = docLinkHtml;\n"," element.appendChild(docLink);\n"," }\n"," </script>\n"," </div>\n","\n","\n","<div id=\"df-3e6a4d79-2d22-42d5-b89b-e8cbfa750de6\">\n"," <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-3e6a4d79-2d22-42d5-b89b-e8cbfa750de6')\"\n"," title=\"Suggest charts\"\n"," style=\"display:none;\">\n","\n","<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n"," width=\"24px\">\n"," <g>\n"," <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n"," </g>\n","</svg>\n"," </button>\n","\n","<style>\n"," .colab-df-quickchart {\n"," --bg-color: #E8F0FE;\n"," --fill-color: #1967D2;\n"," --hover-bg-color: #E2EBFA;\n"," --hover-fill-color: #174EA6;\n"," --disabled-fill-color: #AAA;\n"," --disabled-bg-color: #DDD;\n"," }\n","\n"," [theme=dark] .colab-df-quickchart {\n"," --bg-color: #3B4455;\n"," --fill-color: #D2E3FC;\n"," --hover-bg-color: #434B5C;\n"," --hover-fill-color: #FFFFFF;\n"," --disabled-bg-color: #3B4455;\n"," --disabled-fill-color: #666;\n"," }\n","\n"," .colab-df-quickchart {\n"," background-color: var(--bg-color);\n"," border: none;\n"," border-radius: 50%;\n"," cursor: pointer;\n"," display: none;\n"," fill: var(--fill-color);\n"," height: 32px;\n"," padding: 0;\n"," width: 32px;\n"," }\n","\n"," .colab-df-quickchart:hover {\n"," background-color: var(--hover-bg-color);\n"," box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n"," fill: var(--button-hover-fill-color);\n"," }\n","\n"," .colab-df-quickchart-complete:disabled,\n"," .colab-df-quickchart-complete:disabled:hover {\n"," background-color: var(--disabled-bg-color);\n"," fill: var(--disabled-fill-color);\n"," box-shadow: none;\n"," }\n","\n"," .colab-df-spinner {\n"," border: 2px solid var(--fill-color);\n"," border-color: transparent;\n"," border-bottom-color: var(--fill-color);\n"," animation:\n"," spin 1s steps(1) infinite;\n"," }\n","\n"," @keyframes spin {\n"," 0% {\n"," border-color: transparent;\n"," border-bottom-color: var(--fill-color);\n"," border-left-color: var(--fill-color);\n"," }\n"," 20% {\n"," border-color: transparent;\n"," border-left-color: var(--fill-color);\n"," border-top-color: var(--fill-color);\n"," }\n"," 30% {\n"," border-color: transparent;\n"," border-left-color: var(--fill-color);\n"," border-top-color: var(--fill-color);\n"," border-right-color: var(--fill-color);\n"," }\n"," 40% {\n"," border-color: transparent;\n"," border-right-color: var(--fill-color);\n"," border-top-color: var(--fill-color);\n"," }\n"," 60% {\n"," border-color: transparent;\n"," border-right-color: var(--fill-color);\n"," }\n"," 80% {\n"," border-color: transparent;\n"," border-right-color: var(--fill-color);\n"," border-bottom-color: var(--fill-color);\n"," }\n"," 90% {\n"," border-color: transparent;\n"," border-bottom-color: var(--fill-color);\n"," }\n"," }\n","</style>\n","\n"," <script>\n"," async function quickchart(key) {\n"," const quickchartButtonEl =\n"," document.querySelector('#' + key + ' button');\n"," quickchartButtonEl.disabled = true; // To prevent 
multiple clicks.\n"," quickchartButtonEl.classList.add('colab-df-spinner');\n"," try {\n"," const charts = await google.colab.kernel.invokeFunction(\n"," 'suggestCharts', [key], {});\n"," } catch (error) {\n"," console.error('Error during call to suggestCharts:', error);\n"," }\n"," quickchartButtonEl.classList.remove('colab-df-spinner');\n"," quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n"," }\n"," (() => {\n"," let quickchartButtonEl =\n"," document.querySelector('#df-3e6a4d79-2d22-42d5-b89b-e8cbfa750de6 button');\n"," quickchartButtonEl.style.display =\n"," google.colab.kernel.accessAllowed ? 'block' : 'none';\n"," })();\n"," </script>\n","</div>\n","\n"," </div>\n"," </div>\n"]},"metadata":{},"execution_count":6}]},{"cell_type":"code","source":["file = client.files.create(\n"," file=open('NotRealCorp_financial_data.json',\"rb\"),\n"," purpose='assistants',\n",")"],"metadata":{"id":"naRttfJsEcbR","executionInfo":{"status":"ok","timestamp":1707619780967,"user_tz":-480,"elapsed":1034,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":8,"outputs":[]},{"cell_type":"code","source":["assistant = client.beta.assistants.create(\n"," instructions=\"You are a data scientist assistant. When given data and a query, write the proper code and create the proper visualization\",\n"," model=\"gpt-4-1106-preview\",\n"," tools=[{\"type\": \"code_interpreter\"}],\n"," file_ids=[file.id]\n",")"],"metadata":{"id":"bv21EWvKEjWP","executionInfo":{"status":"ok","timestamp":1707619796761,"user_tz":-480,"elapsed":511,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":9,"outputs":[]},{"cell_type":"code","source":["thread = client.beta.threads.create(\n"," messages=[\n"," {\n"," \"role\": \"user\",\n"," \"content\": \"Calculate profit (revenue minus cost) by quarter and year, and visualize as a line plot across the distribution channels, where the colors of the lines are green, light red, and light blue\",\n"," \"file_ids\": [file.id]\n"," }\n"," ]\n",")"],"metadata":{"id":"e9od4SDmEk_b","executionInfo":{"status":"ok","timestamp":1707619809125,"user_tz":-480,"elapsed":527,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":10,"outputs":[]},{"cell_type":"code","source":["run = client.beta.threads.runs.create(\n"," thread_id=thread.id,\n"," assistant_id=assistant.id,\n",")"],"metadata":{"id":"yZ2zuB63Entp","executionInfo":{"status":"ok","timestamp":1707619818426,"user_tz":-480,"elapsed":504,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":11,"outputs":[]},{"cell_type":"code","source":["messages = client.beta.threads.messages.list(thread_id=thread.id)"],"metadata":{"id":"J8laxybUErur","executionInfo":{"status":"ok","timestamp":1707619837176,"user_tz":-480,"elapsed":512,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":12,"outputs":[]},{"cell_type":"code","source":["import time\n","\n","while True:\n"," messages = client.beta.threads.messages.list(thread_id=thread.id)\n"," try:\n"," #See if image has been created\n"," messages.data[0].content[0].image_file\n"," #Sleep to make sure run has completed\n"," time.sleep(5)\n"," print('Plot created!')\n"," break\n"," except:\n"," time.sleep(10)\n"," print('Assistant still 
working...')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"dsUwZKk6Eu7o","executionInfo":{"status":"ok","timestamp":1707619978003,"user_tz":-480,"elapsed":130282,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"75c2d621-24a1-44ee-a011-844a6f39ca31"},"execution_count":13,"outputs":[{"output_type":"stream","name":"stdout","text":["Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Assistant still working...\n","Plot created!\n"]}]},{"cell_type":"code","source":["messages = client.beta.threads.messages.list(thread_id=thread.id)\n","[message.content[0] for message in messages.data]"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"PJeTKUdGEz_1","executionInfo":{"status":"ok","timestamp":1707620018940,"user_tz":-480,"elapsed":522,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"1c790ac4-4c14-4025-bce5-e5a16ae96744"},"execution_count":15,"outputs":[{"output_type":"execute_result","data":{"text/plain":["[MessageContentImageFile(image_file=ImageFile(file_id='file-0TaYohpWFuqF4SOCMhuy2ujI'), type='image_file'),\n"," MessageContentText(text=Text(annotations=[], value=\"It seems I made an error when attempting to drop the 'Year' and 'Quarter' columns from the pivot table's columns MultiIndex. I need to address this by correctly referencing these columns in the context of a MultiIndex. Let's correct this and try plotting again.\"), type='text'),\n"," MessageContentText(text=Text(annotations=[], value=\"It appears there is an issue with plotting the multi-index ('Year', 'Quarter') directly. To resolve this, we need to convert the multi-index into a single string that represents both Year and Quarter. I will adjust the DataFrame index and attempt the visualization again.\"), type='text'),\n"," MessageContentText(text=Text(annotations=[], value='With the profit calculated and data grouped by \"Year\", \"Quarter\", and \"Distribution channel\", we can now move on to visualizing this information as a line plot.\\n\\nWe will use three different colors for the lines representing different distribution channels as specified:\\n\\n- Online Sales: Light Blue\\n- Direct Sales: Light Red\\n- Retail Partners: Green\\n\\nLet\\'s proceed with the visualization.'), type='text'),\n"," MessageContentText(text=Text(annotations=[], value='We have successfully structured the data into a pandas DataFrame. Now we can proceed to calculate the profit by subtracting the \"Costs ($M)\" column from the \"Revenue ($M)\" column, and then group and visualize the data by year, quarter, and distribution channel. Let\\'s start with the calculation and grouping.'), type='text'),\n"," MessageContentText(text=Text(annotations=[], value='The content of the \"Year\" key appears to be a dictionary where each entry corresponds to a row index and the respective year. 
This suggests that the keys of the main JSON object represent columns that are in dictionaries themselves with row index as keys.\\n\\nI will now construct a pandas DataFrame by iterating over these keys and combining the corresponding dictionaries into a structured table.'), type='text'),\n"," MessageContentText(text=Text(annotations=[], value='The JSON object contains keys corresponding to what seems like different data columns like \"Year\", \"Quarter\", \"Distribution channel\", \"Revenue\", \"Costs\", \"Customer count\", and a combined \"Time\" column.\\n\\nIt might be a bit tricky to deal with this structure directly, so let\\'s see if we can correctly parse and normalize this JSON by columns into a pandas DataFrame. I will explore one of the columns to see its data format.'), type='text'),\n"," MessageContentText(text=Text(annotations=[], value=\"It looks like the data has been loaded incorrectly, resulting in an odd DataFrame structure that seems to include serialized JSON. I will try to unpack this serialized JSON data correctly into a usable pandas DataFrame. Let's proceed by examining the keys in the JSON object so we can determine the correct approach to normalize the data into a DataFrame.\"), type='text'),\n"," MessageContentText(text=Text(annotations=[], value=\"It appears that the attempt to read the file as a typical CSV or Excel file has failed, resulting in an empty DataFrame with an unusual column structure that suggests the data might be stored in a JSON format within the file.\\n\\nI'll now attempt to load the data as JSON to see if that correctly interprets the file structure.\"), type='text'),\n"," MessageContentText(text=Text(annotations=[], value=\"To calculate profit by quarter and year and visualize it across distribution channels as requested, I first need to inspect the contents of the uploaded file to understand its structure (column names, data types, etc.).\\n\\nLet's begin by loading the file and taking a look at the first few rows to determine how to proceed.\"), type='text'),\n"," MessageContentText(text=Text(annotations=[], value='Calculate profit (revenue minus cost) by quarter and year, and visualize as a line plot across the distribution channels, where the colors of the lines are green, light red, and light blue'), type='text')]"]},"metadata":{},"execution_count":15}]},{"cell_type":"code","source":["# Quick helper function to convert our output file to a png\n","def convert_file_to_png(file_id, write_path):\n"," data = client.files.content(file_id)\n"," data_bytes = data.read()\n"," with open(write_path, \"wb\") as file:\n"," file.write(data_bytes)"],"metadata":{"id":"CWyogxdhE5Xl","executionInfo":{"status":"ok","timestamp":1707620019701,"user_tz":-480,"elapsed":4,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":16,"outputs":[]},{"cell_type":"code","source":["plot_file_id = messages.data[0].content[0].image_file.file_id\n","image_path = \"NotRealCorp_chart.png\"\n","convert_file_to_png(plot_file_id,image_path)\n","\n","#Upload\n","plot_file = client.files.create(\n"," file=open(image_path, \"rb\"),\n"," purpose='assistants'\n",")"],"metadata":{"id":"y7x0uRvNE6y-","executionInfo":{"status":"ok","timestamp":1707620195449,"user_tz":-480,"elapsed":2152,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":18,"outputs":[]},{"cell_type":"code","source":["submit_message(assistant.id,thread,\"Give me two medium length sentences (~20-30 words per sentence) of the \\\n"," most 
important insights from the plot you just created.\\\n"," These will be used for a slide deck, and they should be about the\\\n"," 'so what' behind the data.\"\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"V_ITOsiCE98C","executionInfo":{"status":"ok","timestamp":1707620206406,"user_tz":-480,"elapsed":1026,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"8c169d2e-9c74-477d-e90d-7aaaf7cf0e78"},"execution_count":19,"outputs":[{"output_type":"execute_result","data":{"text/plain":["Run(id='run_GWFPQYRVdxr0LVyFQWOgy0mB', assistant_id='asst_s0WPNCpicIs8tCxIn2ILGwu1', cancelled_at=None, completed_at=None, created_at=1707620205, expires_at=1707620805, failed_at=None, file_ids=['file-76sx64aGS7mobabDw6W3WNPe'], instructions='You are a data scientist assistant. When given data and a query, write the proper code and create the proper visualization', last_error=None, metadata={}, model='gpt-4-1106-preview', object='thread.run', required_action=None, started_at=None, status='queued', thread_id='thread_p3iCATH9IfCJKZyfcBxoPjjV', tools=[ToolAssistantToolsCode(type='code_interpreter')], usage=None)"]},"metadata":{},"execution_count":19}]},{"cell_type":"code","source":["# Hard coded wait for a response, as the assistant may iterate on the bullets.\n","time.sleep(10)\n","response = get_response(thread)\n","bullet_points = response.data[0].content[0].text.value\n","print(bullet_points)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"A1UdmH-nE_wK","executionInfo":{"status":"ok","timestamp":1707620227601,"user_tz":-480,"elapsed":11415,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"095e74a5-ff2c-4ea3-87f8-8d9f647eed91"},"execution_count":20,"outputs":[{"output_type":"stream","name":"stdout","text":["The line plot illustrates that Online Sales consistently generates the highest profit across all quarters, indicating a strong market presence and effective online strategies. Despite seasonal fluctuations, Retail Partners show a stable profit trend, while Direct Sales appear to be the least profitable channel with more pronounced variances over time.\n"]}]},{"cell_type":"code","source":["submit_message(assistant.id,thread,\"Given the plot and bullet points you created,\\\n"," come up with a very brief title for a slide. It should reflect just the main insights you came up with.\"\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"MrNeJfmkFBmI","executionInfo":{"status":"ok","timestamp":1707620228114,"user_tz":-480,"elapsed":514,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"ca5ca6c7-d057-4303-c5a2-775124525d76"},"execution_count":21,"outputs":[{"output_type":"execute_result","data":{"text/plain":["Run(id='run_axOn3g0XObuWLz01NBb8RJWa', assistant_id='asst_s0WPNCpicIs8tCxIn2ILGwu1', cancelled_at=None, completed_at=None, created_at=1707620227, expires_at=1707620827, failed_at=None, file_ids=['file-76sx64aGS7mobabDw6W3WNPe'], instructions='You are a data scientist assistant. 
When given data and a query, write the proper code and create the proper visualization', last_error=None, metadata={}, model='gpt-4-1106-preview', object='thread.run', required_action=None, started_at=None, status='queued', thread_id='thread_p3iCATH9IfCJKZyfcBxoPjjV', tools=[ToolAssistantToolsCode(type='code_interpreter')], usage=None)"]},"metadata":{},"execution_count":21}]},{"cell_type":"code","source":["#Wait as assistant may take a few steps\n","time.sleep(10)\n","response = get_response(thread)\n","title = response.data[0].content[0].text.value\n","print(title)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"UPlwpe3UFDk9","executionInfo":{"status":"ok","timestamp":1707620238328,"user_tz":-480,"elapsed":10216,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"65565240-b594-4dbd-e3c4-83b5a9339446"},"execution_count":22,"outputs":[{"output_type":"stream","name":"stdout","text":["\"Online Sales Lead Profits Across Channels with Steady Retail Partner Performance\"\n"]}]},{"cell_type":"code","source":["company_summary = \"NotReal Corp is a prominent hardware company that manufactures and sells processors, graphics cards and other essential computer hardware.\"\n","response = client.images.generate(\n"," model='dall-e-3',\n"," prompt=f\"given this company summary {company_summary}, create an inspirational \\\n"," photo showing the growth and path forward. This will be used at a quarterly\\\n"," financial planning meeting\",\n"," size=\"1024x1024\",\n"," quality=\"hd\",\n"," n=1\n",")\n","image_url = response.data[0].url"],"metadata":{"id":"0nWlwWIqFEsT","executionInfo":{"status":"ok","timestamp":1707620253115,"user_tz":-480,"elapsed":14790,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":23,"outputs":[]},{"cell_type":"code","source":["dalle_img_path = 'dalle_image.png'\n","img = requests.get(image_url)\n","\n","#Save locally\n","with open(dalle_img_path,'wb') as file:\n"," file.write(img.content)\n","\n","#Upload\n","dalle_file = client.files.create(\n"," file=open(dalle_img_path, \"rb\"),\n"," purpose='assistants'\n",")"],"metadata":{"id":"2562HE27FJeF","executionInfo":{"status":"ok","timestamp":1707620275055,"user_tz":-480,"elapsed":6318,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":25,"outputs":[]},{"cell_type":"code","source":["title_template = \"\"\"\n","from pptx import Presentation\n","from pptx.util import Inches, Pt\n","from pptx.enum.text import PP_PARAGRAPH_ALIGNMENT\n","from pptx.dml.color import RGBColor\n","\n","# Create a new presentation object\n","prs = Presentation()\n","\n","# Add a blank slide layout\n","blank_slide_layout = prs.slide_layouts[6]\n","slide = prs.slides.add_slide(blank_slide_layout)\n","\n","# Set the background color of the slide to black\n","background = slide.background\n","fill = background.fill\n","fill.solid()\n","fill.fore_color.rgb = RGBColor(0, 0, 0)\n","\n","# Add image to the left side of the slide with a margin at the top and bottom\n","left = Inches(0)\n","top = Inches(0)\n","height = prs.slide_height\n","width = prs.slide_width * 3/5\n","pic = slide.shapes.add_picture(image_path, left, top, width=width, height=height)\n","\n","# Add title text box positioned higher\n","left = prs.slide_width * 3/5\n","top = Inches(2)\n","width = prs.slide_width * 2/5\n","height = Inches(1)\n","title_box = slide.shapes.add_textbox(left, top, width, height)\n","title_frame = title_box.text_frame\n","title_p = 
title_frame.add_paragraph()\n","title_p.text = title_text\n","title_p.font.bold = True\n","title_p.font.size = Pt(38)\n","title_p.font.color.rgb = RGBColor(255, 255, 255)\n","title_p.alignment = PP_PARAGRAPH_ALIGNMENT.CENTER\n","\n","# Add subtitle text box\n","left = prs.slide_width * 3/5\n","top = Inches(3)\n","width = prs.slide_width * 2/5\n","height = Inches(1)\n","subtitle_box = slide.shapes.add_textbox(left, top, width, height)\n","subtitle_frame = subtitle_box.text_frame\n","subtitle_p = subtitle_frame.add_paragraph()\n","subtitle_p.text = subtitle_text\n","subtitle_p.font.size = Pt(22)\n","subtitle_p.font.color.rgb = RGBColor(255, 255, 255)\n","subtitle_p.alignment = PP_PARAGRAPH_ALIGNMENT.CENTER\n","\"\"\"\n","\n","data_vis_template = \"\"\"\n","from pptx import Presentation\n","from pptx.util import Inches, Pt\n","from pptx.enum.text import PP_PARAGRAPH_ALIGNMENT\n","from pptx.dml.color import RGBColor\n","\n","# Create a new presentation object\n","prs = Presentation()\n","\n","# Add a blank slide layout\n","blank_slide_layout = prs.slide_layouts[6]\n","slide = prs.slides.add_slide(blank_slide_layout)\n","\n","# Set the background color of the slide to black\n","background = slide.background\n","fill = background.fill\n","fill.solid()\n","fill.fore_color.rgb = RGBColor(0, 0, 0)\n","\n","# Define placeholders\n","image_path = data_vis_img\n","title_text = \"Maximizing Profits: The Dominance of Online Sales & Direct Sales Optimization\"\n","bullet_points = \"• Online Sales consistently lead in profitability across quarters, indicating a strong digital market presence.\\n• Direct Sales show fluctuations, suggesting variable performance and the need for targeted improvements in that channel.\"\n","\n","# Add image placeholder on the left side of the slide\n","left = Inches(0.2)\n","top = Inches(1.8)\n","height = prs.slide_height - Inches(3)\n","width = prs.slide_width * 3/5\n","pic = slide.shapes.add_picture(image_path, left, top, width=width, height=height)\n","\n","# Add title text spanning the whole width\n","left = Inches(0)\n","top = Inches(0)\n","width = prs.slide_width\n","height = Inches(1)\n","title_box = slide.shapes.add_textbox(left, top, width, height)\n","title_frame = title_box.text_frame\n","title_frame.margin_top = Inches(0.1)\n","title_p = title_frame.add_paragraph()\n","title_p.text = title_text\n","title_p.font.bold = True\n","title_p.font.size = Pt(28)\n","title_p.font.color.rgb = RGBColor(255, 255, 255)\n","title_p.alignment = PP_PARAGRAPH_ALIGNMENT.CENTER\n","\n","# Add hardcoded \"Key Insights\" text and bullet points\n","left = prs.slide_width * 2/3\n","top = Inches(1.5)\n","width = prs.slide_width * 1/3\n","height = Inches(4.5)\n","insights_box = slide.shapes.add_textbox(left, top, width, height)\n","insights_frame = insights_box.text_frame\n","insights_p = insights_frame.add_paragraph()\n","insights_p.text = \"Key Insights:\"\n","insights_p.font.bold = True\n","insights_p.font.size = Pt(24)\n","insights_p.font.color.rgb = RGBColor(0, 128, 100)\n","insights_p.alignment = PP_PARAGRAPH_ALIGNMENT.LEFT\n","insights_frame.add_paragraph()\n","\n","\n","bullet_p = insights_frame.add_paragraph()\n","bullet_p.text = bullet_points\n","bullet_p.font.size = Pt(12)\n","bullet_p.font.color.rgb = RGBColor(255, 255, 255)\n","bullet_p.line_spacing = 1.5\n","\"\"\""],"metadata":{"id":"0i-dDVFkFOaO","executionInfo":{"status":"ok","timestamp":1707620278192,"user_tz":-480,"elapsed":502,"user":{"displayName":"Liu 
LingHui","userId":"15918876591374452686"}}},"execution_count":26,"outputs":[]},{"cell_type":"code","source":["title_text = \"NotRealCorp\"\n","subtitle_text = \"Quarterly financial planning meeting, Q3 2023\""],"metadata":{"id":"E5prceoBGQcz","executionInfo":{"status":"ok","timestamp":1707620339642,"user_tz":-480,"elapsed":2,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":28,"outputs":[]},{"cell_type":"code","source":["submit_message(assistant.id,thread,f\"Use the included code template to create a PPTX slide that follows the template format, but uses the image, company name/title, and document name/subtitle included:\\\n","{title_template}. IMPORTANT: Use the image file included in this message as the image_path image in this first slide, and use the Company Name {title_text} as the title_text variable, and \\\n"," use the subtitle_text {subtitle_text} a the subtitle_text variable. \\\n"," NEST, create a SECOND slide using the following code template: {data_vis_template} to create a PPTX slide that follows the template format, but uses the company name/title, and document name/subtitle included:\\\n","{data_vis_template}. IMPORTANT: Use the line plot image, that is the second attached image in this message, that you created earlier in the thread as the data_vis_img image, and use the data visualization title that you created earlier for the variable title_text, and\\\n"," the bullet points of insights you created earlier for the bullet_points variable. Output these TWO SLIDES as a .pptx file. Make sure the output is two slides, with each slide matching the respective template given in this message.\",\n"," file_ids=[dalle_file.id, plot_file.id]\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"NAHKthE5Ga02","executionInfo":{"status":"ok","timestamp":1707620340656,"user_tz":-480,"elapsed":1016,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"33508da3-71e5-4df7-9ea8-26e6eae66f21"},"execution_count":29,"outputs":[{"output_type":"execute_result","data":{"text/plain":["Run(id='run_eQvuBgDgLPEBDCE3C38UAnUM', assistant_id='asst_s0WPNCpicIs8tCxIn2ILGwu1', cancelled_at=None, completed_at=None, created_at=1707620339, expires_at=1707620939, failed_at=None, file_ids=['file-76sx64aGS7mobabDw6W3WNPe'], instructions='You are a data scientist assistant. 
When given data and a query, write the proper code and create the proper visualization', last_error=None, metadata={}, model='gpt-4-1106-preview', object='thread.run', required_action=None, started_at=None, status='queued', thread_id='thread_p3iCATH9IfCJKZyfcBxoPjjV', tools=[ToolAssistantToolsCode(type='code_interpreter')], usage=None)"]},"metadata":{},"execution_count":29}]},{"cell_type":"code","source":["#May take 1-3 mins\n","while True:\n"," try:\n"," response = get_response(thread)\n"," pptx_id = response.data[0].content[0].text.annotations[0].file_path.file_id\n"," print(\"Successfully retrieved pptx_id:\", pptx_id)\n"," break\n"," except Exception as e:\n"," print(\"Assistant still working on PPTX...\")\n"," time.sleep(10)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"XIPEkk4TGbMS","executionInfo":{"status":"ok","timestamp":1707620436104,"user_tz":-480,"elapsed":95450,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}},"outputId":"4f281d7b-019d-410d-9218-60ca56479b64"},"execution_count":30,"outputs":[{"output_type":"stream","name":"stdout","text":["Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Assistant still working on PPTX...\n","Successfully retrieved pptx_id: file-NwArFbjmmFw0tnboToe6f18N\n"]}]},{"cell_type":"code","source":["pptx_id = response.data[0].content[0].text.annotations[0].file_path.file_id\n","ppt_file= client.files.content(pptx_id)\n","file_obj = io.BytesIO(ppt_file.read())\n","with open(\"created_slides.pptx\", \"wb\") as f:\n"," f.write(file_obj.getbuffer())"],"metadata":{"id":"CtqZHLolGbyT","executionInfo":{"status":"ok","timestamp":1707620446153,"user_tz":-480,"elapsed":1151,"user":{"displayName":"Liu LingHui","userId":"15918876591374452686"}}},"execution_count":32,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"MxJ_Co2ZGoxf"},"execution_count":null,"outputs":[]}]}
\ No newline at end of file
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# How to use the DALL·E API\n",
"\n",
"This notebook shows how to use OpenAI's DALL·E image API endpoints.\n",
"\n",
"There are three API endpoints:\n",
"- **Generations:** generates an image or images based on an input caption\n",
"- **Edits:** edits or extends an existing image\n",
"- **Variations:** generates variations of an input image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"- Import the packages you'll need\n",
"- Import your OpenAI API key: You can do this by running \\``export OPENAI_API_KEY=\"your API key\"`\\` in your terminal.\n",
"- Set a directory to save images to"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"from openai import OpenAI # OpenAI Python library to make API calls\n",
"import requests # used to download images\n",
"import os # used to access filepaths\n",
"from PIL import Image # used to print and edit images\n",
"\n",
"# initialize OpenAI client\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"image_dir='./images'\n"
]
}
],
"source": [
"# set a directory to save DALL·E images to\n",
"image_dir_name = \"images\"\n",
"image_dir = os.path.join(os.curdir, image_dir_name)\n",
"\n",
"# create the directory if it doesn't yet exist\n",
"if not os.path.isdir(image_dir):\n",
" os.mkdir(image_dir)\n",
"\n",
"# print the directory to save to\n",
"print(f\"{image_dir=}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generations\n",
"\n",
"The generation API endpoint creates an image based on a text prompt. [API Reference](https://platform.openai.com/docs/api-reference/images/create)\n",
"\n",
"**Required inputs:**\n",
"- `prompt` (str): A text description of the desired image(s). The maximum length is 1000 characters for dall-e-2 and 4000 characters for dall-e-3.\n",
"\n",
"**Optional inputs:**\n",
"- `model` (str): The model to use for image generation. Defaults to dall-e-2\n",
"- `n` (int): The number of images to generate. Must be between 1 and 10. Defaults to 1.\n",
"- `quality` (str): The quality of the image that will be generated. hd creates images with finer details and greater consistency across the image. This param is only supported for dall-e-3.\n",
"- `response_format` (str): The format in which the generated images are returned. Must be one of \"url\" or \"b64_json\". Defaults to \"url\".\n",
"- `size` (str): The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024 for dall-e-2. Must be one of 1024x1024, 1792x1024, or 1024x1792 for dall-e-3 models. Defaults to \"1024x1024\".\n",
"- `style`(str | null): The style of the generated images. Must be one of vivid or natural. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. This param is only supported for dall-e-3.\n",
"- `user` (str): A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. [Learn more.](https://beta.openai.com/docs/usage-policies/end-user-ids)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ImagesResponse(created=1701994117, data=[Image(b64_json=None, revised_prompt=None, url='https://oaidalleapiprodscus.blob.core.windows.net/private/org-9HXYFy8ux4r6aboFyec2OLRf/user-8OA8IvMYkfdAcUZXgzAXHS7d/img-ced13hkOk3lXkccQgW1fAQjm.png?st=2023-12-07T23%3A08%3A37Z&se=2023-12-08T01%3A08%3A37Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2023-12-07T16%3A41%3A48Z&ske=2023-12-08T16%3A41%3A48Z&sks=b&skv=2021-08-06&sig=tcD0iyU0ABOvWAKsY89gp5hLVIYkoSXQnrcmH%2Brkric%3D')])\n"
]
}
],
"source": [
"# create an image\n",
"\n",
"# set the prompt\n",
"prompt = \"A cyberpunk monkey hacker dreaming of a beautiful bunch of bananas, digital art\"\n",
"\n",
"# call the OpenAI API\n",
"generation_response = client.images.generate(\n",
" model = \"dall-e-3\"\n",
" prompt=prompt,\n",
" n=1,\n",
" size=\"1024x1024\",\n",
" response_format=\"url\",\n",
")\n",
"\n",
"# print response\n",
"print(generation_response)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# save the image\n",
"generated_image_name = \"generated_image.png\" # any name you like; the filetype should be .png\n",
"generated_image_filepath = os.path.join(image_dir, generated_image_name)\n",
"generated_image_url = generation_response.data[0].url # extract image URL from response\n",
"generated_image = requests.get(generated_image_url).content # download the image\n",
"\n",
"with open(generated_image_filepath, \"wb\") as image_file:\n",
" image_file.write(generated_image) # write the image to the file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# print the image\n",
"print(generated_image_filepath)\n",
"display(Image.open(generated_image_filepath))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Variations\n",
"\n",
"The variations endpoint generates new images (variations) similar to an input image. [API Reference](https://platform.openai.com/docs/api-reference/images/createVariation)\n",
"\n",
"Here we'll generate variations of the image generated above.\n",
"\n",
"**Required inputs:**\n",
"- `image` (str): The image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square.\n",
"\n",
"**Optional inputs:**\n",
"- `model` (str): The model to use for image variations. Only dall-e-2 is supported at this time.\n",
"- `n` (int): The number of images to generate. Must be between 1 and 10. Defaults to 1.\n",
"- `size` (str): The size of the generated images. Must be one of \"256x256\", \"512x512\", or \"1024x1024\". Smaller images are faster. Defaults to \"1024x1024\".\n",
"- `response_format` (str): The format in which the generated images are returned. Must be one of \"url\" or \"b64_json\". Defaults to \"url\".\n",
"- `user` (str): A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. [Learn more.](https://beta.openai.com/docs/usage-policies/end-user-ids)\n"
]
},
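  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before calling the variations endpoint, you can optionally confirm that the input meets the requirements listed above (a square PNG under 4MB). This small check is an assumed convenience, not part of the original notebook; it reuses `generated_image_filepath` from the generation section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional sanity check (an assumed helper, not required by the API itself):\n",
    "# the variations endpoint expects a square PNG file smaller than 4MB\n",
    "with Image.open(generated_image_filepath) as img:\n",
    "    image_width, image_height = img.size\n",
    "    image_format = img.format\n",
    "\n",
    "file_size_bytes = os.path.getsize(generated_image_filepath)\n",
    "print(f\"format={image_format}, size={image_width}x{image_height}, bytes={file_size_bytes}\")\n",
    "\n",
    "assert image_format == \"PNG\", \"input must be a PNG file\"\n",
    "assert image_width == image_height, \"input must be square\"\n",
    "assert file_size_bytes < 4 * 1024 * 1024, \"input must be smaller than 4MB\""
   ]
  },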
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ImagesResponse(created=1701994139, data=[Image(b64_json=None, revised_prompt=None, url='https://oaidalleapiprodscus.blob.core.windows.net/private/org-9HXYFy8ux4r6aboFyec2OLRf/user-8OA8IvMYkfdAcUZXgzAXHS7d/img-noNRGgwaaotRGIe6Y2GVeSpr.png?st=2023-12-07T23%3A08%3A59Z&se=2023-12-08T01%3A08%3A59Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2023-12-07T16%3A39%3A11Z&ske=2023-12-08T16%3A39%3A11Z&sks=b&skv=2021-08-06&sig=ER5RUglhtIk9LWJXw1DsolorT4bnEmFostfnUjY21ns%3D'), Image(b64_json=None, revised_prompt=None, url='https://oaidalleapiprodscus.blob.core.windows.net/private/org-9HXYFy8ux4r6aboFyec2OLRf/user-8OA8IvMYkfdAcUZXgzAXHS7d/img-oz952tL11FFhf9iXXJVIRUZX.png?st=2023-12-07T23%3A08%3A59Z&se=2023-12-08T01%3A08%3A59Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2023-12-07T16%3A39%3A11Z&ske=2023-12-08T16%3A39%3A11Z&sks=b&skv=2021-08-06&sig=99rJOQwDKsfIeerlMXMHholhAhrHfYaQRFJBF8FKv74%3D')])\n"
]
}
],
"source": [
"# create variations\n",
"\n",
"# call the OpenAI API, using `create_variation` rather than `create`\n",
"variation_response = client.images.create_variation(\n",
" image=generated_image, # generated_image is the image generated above\n",
" n=2,\n",
" size=\"1024x1024\",\n",
" response_format=\"url\",\n",
")\n",
"\n",
"# print response\n",
"print(variation_response)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# save the images\n",
"variation_urls = [datum.url for datum in variation_response.data] # extract URLs\n",
"variation_images = [requests.get(url).content for url in variation_urls] # download images\n",
"variation_image_names = [f\"variation_image_{i}.png\" for i in range(len(variation_images))] # create names\n",
"variation_image_filepaths = [os.path.join(image_dir, name) for name in variation_image_names] # create filepaths\n",
"for image, filepath in zip(variation_images, variation_image_filepaths): # loop through the variations\n",
" with open(filepath, \"wb\") as image_file: # open the file\n",
" image_file.write(image) # write the image to the file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# print the original image\n",
"print(generated_image_filepath)\n",
"display(Image.open(generated_image_filepath))\n",
"\n",
"# print the new variations\n",
"for variation_image_filepaths in variation_image_filepaths:\n",
" print(variation_image_filepaths)\n",
" display(Image.open(variation_image_filepaths))\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Edits\n",
"\n",
"The edit endpoint uses DALL·E to generate a specified portion of an existing image. Three inputs are needed: the image to edit, a mask specifying the portion to be regenerated, and a prompt describing the desired image. [API Reference](https://platform.openai.com/docs/api-reference/images/createEdit)\n",
"\n",
"**Required inputs:** \n",
"- `image` (str): The image to edit. Must be a valid PNG file, less than 4MB, and square. If mask is not provided, image must have transparency, which will be used as the mask.\n",
"- `prompt` (str): A text description of the desired image(s). The maximum length is 1000 characters.\n",
"\n",
"**Optional inputs:**\n",
"- `mask` (file): An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where image should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as image.\n",
"- `model` (str): The model to use for edit image. Only dall-e-2 is supported at this time.\n",
"- `n` (int): The number of images to generate. Must be between 1 and 10. Defaults to 1.\n",
"- `size` (str): The size of the generated images. Must be one of \"256x256\", \"512x512\", or \"1024x1024\". Smaller images are faster. Defaults to \"1024x1024\".\n",
"- `response_format` (str): The format in which the generated images are returned. Must be one of \"url\" or \"b64_json\". Defaults to \"url\".\n",
"- `user` (str): A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. [Learn more.](https://beta.openai.com/docs/usage-policies/end-user-ids)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set Edit Area\n",
"\n",
"An edit requires a \"mask\" to specify which portion of the image to regenerate. Any pixel with an alpha of 0 (transparent) will be regenerated. The code below creates a 1024x1024 mask where the bottom half is transparent."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# create a mask\n",
"width = 1024\n",
"height = 1024\n",
"mask = Image.new(\"RGBA\", (width, height), (0, 0, 0, 1)) # create an opaque image mask\n",
"\n",
"# set the bottom half to be transparent\n",
"for x in range(width):\n",
" for y in range(height // 2, height): # only loop over the bottom half of the mask\n",
" # set alpha (A) to zero to turn pixel transparent\n",
" alpha = 0\n",
" mask.putpixel((x, y), (0, 0, 0, alpha))\n",
"\n",
"# save the mask\n",
"mask_name = \"bottom_half_mask.png\"\n",
"mask_filepath = os.path.join(image_dir, mask_name)\n",
"mask.save(mask_filepath)"
]
},
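  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As noted above, the mask must have the same dimensions as the image being edited. The short check below is an optional, assumed addition that verifies this before calling the edits endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional: confirm the mask matches the generated image's dimensions,\n",
    "# as required by the edits endpoint (this check is an assumed convenience)\n",
    "with Image.open(generated_image_filepath) as img, Image.open(mask_filepath) as mask_img:\n",
    "    print(f\"image size: {img.size}, mask size: {mask_img.size}\")\n",
    "    assert img.size == mask_img.size, \"mask must have the same dimensions as the image\""
   ]
  },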
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Perform Edit\n",
"\n",
"Now we supply our image, caption and mask to the API to get 5 examples of edits to our image"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ImagesResponse(created=1701994167, data=[Image(b64_json=None, revised_prompt=None, url='https://oaidalleapiprodscus.blob.core.windows.net/private/org-9HXYFy8ux4r6aboFyec2OLRf/user-8OA8IvMYkfdAcUZXgzAXHS7d/img-9UOVGC7wB8MS2Q7Rwgj0fFBq.png?st=2023-12-07T23%3A09%3A27Z&se=2023-12-08T01%3A09%3A27Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2023-12-07T16%3A40%3A37Z&ske=2023-12-08T16%3A40%3A37Z&sks=b&skv=2021-08-06&sig=MsRMZ1rN434bVdWr%2B9kIoqu9CIrvZypZBfkQPTOhCl4%3D')])\n"
]
}
],
"source": [
"# edit an image\n",
"\n",
"# call the OpenAI API\n",
"edit_response = client.images.edit(\n",
" image=open(generated_image_filepath, \"rb\"), # from the generation section\n",
" mask=open(mask_filepath, \"rb\"), # from right above\n",
" prompt=prompt, # from the generation section\n",
" n=1,\n",
" size=\"1024x1024\",\n",
" response_format=\"url\",\n",
")\n",
"\n",
"# print response\n",
"print(edit_response)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# save the image\n",
"edited_image_name = \"edited_image.png\" # any name you like; the filetype should be .png\n",
"edited_image_filepath = os.path.join(image_dir, edited_image_name)\n",
"edited_image_url = edit_response.data[0].url # extract image URL from response\n",
"edited_image = requests.get(edited_image_url).content # download the image\n",
"\n",
"with open(edited_image_filepath, \"wb\") as image_file:\n",
" image_file.write(edited_image) # write the image to the file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# print the original image\n",
"print(generated_image_filepath)\n",
"display(Image.open(generated_image_filepath))\n",
"\n",
"# print edited image\n",
"print(edited_image_filepath)\n",
"display(Image.open(edited_image_filepath))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.9 ('openai')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
},
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyPWNySDXDUOIRRHytPmodaP"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"hXVFx1TUe9Rc"},"outputs":[],"source":["import os\n","from copy import deepcopy\n","from dataclasses import dataclass\n","from typing import Dict, List, Optional, Tuple\n","from datasets import load_dataset, set_caching_enabled\n","import numpy as np\n","from PIL import Image\n","import torch\n","import torch.nn as nn\n","from transformers import (\n"," # Preprocessing / Common\n"," AutoTokenizer, AutoFeatureExtractor,\n"," # Text & Image Models (Now, image transformers like ViTModel, DeiTModel, BEiT can also be loaded using AutoModel)\n"," AutoModel,\n"," # Training / Evaluation\n"," TrainingArguments, Trainer,\n"," # Misc\n"," logging\n",")\n","\n","# import nltk\n","# nltk.download('wordnet')\n","from nltk.corpus import wordnet\n","\n","from sklearn.metrics import accuracy_score, f1_score"]},{"cell_type":"code","source":["# SET CACHE FOR HUGGINGFACE TRANSFORMERS + DATASETS\n","os.environ['HF_HOME'] = os.path.join(\".\", \"cache\")\n","# SET ONLY 1 GPU DEVICE\n","os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n","\n","set_caching_enabled(True)\n","logging.set_verbosity_error()"],"metadata":{"id":"0WpGBaJffO9Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n","print(device)\n","\n","#Additional Info when using cuda\n","if device.type == 'cuda':\n"," print(torch.cuda.get_device_name(0))\n","# print('Memory Usage:')\n","# print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')\n","# print('Cached: ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')"],"metadata":{"id":"A2es_I4_fQ_F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset = load_dataset(\n"," \"csv\",\n"," data_files={\n"," \"train\": os.path.join(\"..\", \"dataset\", \"data_train.csv\"),\n"," \"test\": os.path.join(\"..\", \"dataset\", \"data_eval.csv\")\n"," }\n",")\n","\n","with open(os.path.join(\"..\", \"dataset\", \"answer_space.txt\")) as f:\n"," answer_space = f.read().splitlines()\n","\n","dataset = dataset.map(\n"," lambda examples: {\n"," 'label': [\n"," answer_space.index(ans.replace(\" \", \"\").split(\",\")[0]) # Select the 1st answer if multiple answers are provided\n"," for ans in examples['answer']\n"," ]\n"," },\n"," batched=True\n",")\n","\n","dataset"],"metadata":{"id":"2q5Rc9wMfTYA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from IPython.display import display\n","\n","def showExample(train=True, id=None):\n"," if train:\n"," data = dataset[\"train\"]\n"," else:\n"," data = dataset[\"test\"]\n"," if id == None:\n"," id = np.random.randint(len(data))\n"," image = Image.open(os.path.join(\"..\", \"dataset\", \"images\", data[id][\"image_id\"] + \".png\"))\n"," display(image)\n","\n"," print(\"Question:\\t\", data[id][\"question\"])\n"," print(\"Answer:\\t\\t\", data[id][\"answer\"], \"(Label: {0})\".format(data[id][\"label\"]))"],"metadata":{"id":"RvMZjudffY-A"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["showExample()"],"metadata":{"id":"IbbXsvzlfch7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["@dataclass\n","class MultimodalCollator:\n"," tokenizer: AutoTokenizer\n"," preprocessor: AutoFeatureExtractor\n","\n"," def 
tokenize_text(self, texts: List[str]):\n"," encoded_text = self.tokenizer(\n"," text=texts,\n"," padding='longest',\n"," max_length=24,\n"," truncation=True,\n"," return_tensors='pt',\n"," return_token_type_ids=True,\n"," return_attention_mask=True,\n"," )\n"," return {\n"," \"input_ids\": encoded_text['input_ids'].squeeze(),\n"," \"token_type_ids\": encoded_text['token_type_ids'].squeeze(),\n"," \"attention_mask\": encoded_text['attention_mask'].squeeze(),\n"," }\n","\n"," def preprocess_images(self, images: List[str]):\n"," processed_images = self.preprocessor(\n"," images=[Image.open(os.path.join(\"..\", \"dataset\", \"images\", image_id + \".png\")).convert('RGB') for image_id in images],\n"," return_tensors=\"pt\",\n"," )\n"," return {\n"," \"pixel_values\": processed_images['pixel_values'].squeeze(),\n"," }\n","\n"," def __call__(self, raw_batch_dict):\n"," return {\n"," **self.tokenize_text(\n"," raw_batch_dict['question']\n"," if isinstance(raw_batch_dict, dict) else\n"," [i['question'] for i in raw_batch_dict]\n"," ),\n"," **self.preprocess_images(\n"," raw_batch_dict['image_id']\n"," if isinstance(raw_batch_dict, dict) else\n"," [i['image_id'] for i in raw_batch_dict]\n"," ),\n"," 'labels': torch.tensor(\n"," raw_batch_dict['label']\n"," if isinstance(raw_batch_dict, dict) else\n"," [i['label'] for i in raw_batch_dict],\n"," dtype=torch.int64\n"," ),\n"," }"],"metadata":{"id":"arT0QhVCfeXZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class MultimodalVQAModel(nn.Module):\n"," def __init__(\n"," self,\n"," num_labels: int = len(answer_space),\n"," intermediate_dim: int = 512,\n"," pretrained_text_name: str = 'bert-base-uncased',\n"," pretrained_image_name: str = 'google/vit-base-patch16-224-in21k'):\n","\n"," super(MultimodalVQAModel, self).__init__()\n"," self.num_labels = num_labels\n"," self.pretrained_text_name = pretrained_text_name\n"," self.pretrained_image_name = pretrained_image_name\n","\n"," self.text_encoder = AutoModel.from_pretrained(\n"," self.pretrained_text_name,\n"," )\n"," self.image_encoder = AutoModel.from_pretrained(\n"," self.pretrained_image_name,\n"," )\n"," self.fusion = nn.Sequential(\n"," nn.Linear(self.text_encoder.config.hidden_size + self.image_encoder.config.hidden_size, intermediate_dim),\n"," nn.ReLU(),\n"," nn.Dropout(0.5),\n"," )\n","\n"," self.classifier = nn.Linear(intermediate_dim, self.num_labels)\n","\n"," self.criterion = nn.CrossEntropyLoss()\n","\n"," def forward(\n"," self,\n"," input_ids: torch.LongTensor,\n"," pixel_values: torch.FloatTensor,\n"," attention_mask: Optional[torch.LongTensor] = None,\n"," token_type_ids: Optional[torch.LongTensor] = None,\n"," labels: Optional[torch.LongTensor] = None):\n","\n"," encoded_text = self.text_encoder(\n"," input_ids=input_ids,\n"," attention_mask=attention_mask,\n"," token_type_ids=token_type_ids,\n"," return_dict=True,\n"," )\n"," encoded_image = self.image_encoder(\n"," pixel_values=pixel_values,\n"," return_dict=True,\n"," )\n"," fused_output = self.fusion(\n"," torch.cat(\n"," [\n"," encoded_text['pooler_output'],\n"," encoded_image['pooler_output'],\n"," ],\n"," dim=1\n"," )\n"," )\n"," logits = self.classifier(fused_output)\n","\n"," out = {\n"," \"logits\": logits\n"," }\n"," if labels is not None:\n"," loss = self.criterion(logits, labels)\n"," out[\"loss\"] = loss\n","\n"," return out"],"metadata":{"id":"l_IgMBN2fi2y"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def createMultimodalVQACollatorAndModel(text='bert-base-uncased', 
image='google/vit-base-patch16-224-in21k'):\n"," tokenizer = AutoTokenizer.from_pretrained(text)\n"," preprocessor = AutoFeatureExtractor.from_pretrained(image)\n","\n"," multi_collator = MultimodalCollator(\n"," tokenizer=tokenizer,\n"," preprocessor=preprocessor,\n"," )\n","\n","\n"," multi_model = MultimodalVQAModel(pretrained_text_name=text, pretrained_image_name=image).to(device)\n"," return multi_collator, multi_model"],"metadata":{"id":"4nQFCbszfo2H"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def wup_measure(a,b,similarity_threshold=0.925):\n"," \"\"\"\n"," Returns Wu-Palmer similarity score.\n"," More specifically, it computes:\n"," max_{x \\in interp(a)} max_{y \\in interp(b)} wup(x,y)\n"," where interp is a 'interpretation field'\n"," \"\"\"\n"," def get_semantic_field(a):\n"," weight = 1.0\n"," semantic_field = wordnet.synsets(a,pos=wordnet.NOUN)\n"," return (semantic_field,weight)\n","\n","\n"," def get_stem_word(a):\n"," \"\"\"\n"," Sometimes answer has form word\\d+:wordid.\n"," If so we return word and downweight\n"," \"\"\"\n"," weight = 1.0\n"," return (a,weight)\n","\n","\n"," global_weight=1.0\n","\n"," (a,global_weight_a)=get_stem_word(a)\n"," (b,global_weight_b)=get_stem_word(b)\n"," global_weight = min(global_weight_a,global_weight_b)\n","\n"," if a==b:\n"," # they are the same\n"," return 1.0*global_weight\n","\n"," if a==[] or b==[]:\n"," return 0\n","\n","\n"," interp_a,weight_a = get_semantic_field(a)\n"," interp_b,weight_b = get_semantic_field(b)\n","\n"," if interp_a == [] or interp_b == []:\n"," return 0\n","\n"," # we take the most optimistic interpretation\n"," global_max=0.0\n"," for x in interp_a:\n"," for y in interp_b:\n"," local_score=x.wup_similarity(y)\n"," if local_score > global_max:\n"," global_max=local_score\n","\n"," # we need to use the semantic fields and therefore we downweight\n"," # unless the score is high which indicates both are synonyms\n"," if global_max < similarity_threshold:\n"," interp_weight = 0.1\n"," else:\n"," interp_weight = 1.0\n","\n"," final_score=global_max*weight_a*weight_b*interp_weight*global_weight\n"," return final_score"],"metadata":{"id":"s85RDZTZfqpk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def batch_wup_measure(labels, preds):\n"," wup_scores = [wup_measure(answer_space[label], answer_space[pred]) for label, pred in zip(labels, preds)]\n"," return np.mean(wup_scores)"],"metadata":{"id":"PLq55HN2fv6j"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["labels = np.random.randint(len(answer_space), size=5)\n","preds = np.random.randint(len(answer_space), size=5)\n","\n","def showAnswers(ids):\n"," print([answer_space[id] for id in ids])\n","\n","showAnswers(labels)\n","showAnswers(preds)\n","\n","print(\"Predictions vs Labels: \", batch_wup_measure(labels, preds))\n","print(\"Labels vs Labels: \", batch_wup_measure(labels, labels))"],"metadata":{"id":"qVhNj3mFfyP7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def compute_metrics(eval_tuple: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float]:\n"," logits, labels = eval_tuple\n"," preds = logits.argmax(axis=-1)\n"," return {\n"," \"wups\": batch_wup_measure(labels, preds),\n"," \"acc\": accuracy_score(labels, preds),\n"," \"f1\": f1_score(labels, preds, average='macro')\n"," }"],"metadata":{"id":"TZ-AmQhLf0CA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["args = TrainingArguments(\n"," output_dir=\"checkpoint\",\n"," seed=12345,\n"," 
evaluation_strategy=\"steps\",\n"," eval_steps=100,\n"," logging_strategy=\"steps\",\n"," logging_steps=100,\n"," save_strategy=\"steps\",\n"," save_steps=100,\n"," save_total_limit=3, # Save only the last 3 checkpoints at any given time while training\n"," metric_for_best_model='wups',\n"," per_device_train_batch_size=32,\n"," per_device_eval_batch_size=32,\n"," remove_unused_columns=False,\n"," num_train_epochs=5,\n"," fp16=True,\n"," # warmup_ratio=0.01,\n"," # learning_rate=5e-4,\n"," # weight_decay=1e-4,\n"," # gradient_accumulation_steps=2,\n"," dataloader_num_workers=8,\n"," load_best_model_at_end=True,\n",")"],"metadata":{"id":"VFzqARref1o1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def createAndTrainModel(dataset, args, text_model='bert-base-uncased', image_model='google/vit-base-patch16-224-in21k', multimodal_model='bert_vit'):\n"," collator, model = createMultimodalVQACollatorAndModel(text_model, image_model)\n","\n"," multi_args = deepcopy(args)\n"," multi_args.output_dir = os.path.join(\"..\", \"checkpoint\", multimodal_model)\n"," multi_trainer = Trainer(\n"," model,\n"," multi_args,\n"," train_dataset=dataset['train'],\n"," eval_dataset=dataset['test'],\n"," data_collator=collator,\n"," compute_metrics=compute_metrics\n"," )\n","\n"," train_multi_metrics = multi_trainer.train()\n"," eval_multi_metrics = multi_trainer.evaluate()\n","\n"," return collator, model, train_multi_metrics, eval_multi_metrics"],"metadata":{"id":"DmZE9CWaf3sr"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["collator, model, train_multi_metrics, eval_multi_metrics = createAndTrainModel(dataset, args)"],"metadata":{"id":"F0EgICdAf5oo"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["eval_multi_metrics"],"metadata":{"id":"Q1rJI1cBf7yl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["model = MultimodalVQAModel()\n","\n","# We use the checkpoint giving best results\n","model.load_state_dict(torch.load(os.path.join(\"..\", \"checkpoint\", \"bert_vit\", \"checkpoint-1500\", \"pytorch_model.bin\")))\n","model.to(device)"],"metadata":{"id":"SZvr0wC8f9ZL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["sample = collator(dataset[\"test\"][2000:2005])\n","\n","input_ids = sample[\"input_ids\"].to(device)\n","token_type_ids = sample[\"token_type_ids\"].to(device)\n","attention_mask = sample[\"attention_mask\"].to(device)\n","pixel_values = sample[\"pixel_values\"].to(device)\n","labels = sample[\"labels\"].to(device)"],"metadata":{"id":"eJT7NPq7f_xF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["model.eval()\n","output = model(input_ids, pixel_values, attention_mask, token_type_ids, labels)"],"metadata":{"id":"UYRJqDi8gDmH"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["preds = output[\"logits\"].argmax(axis=-1).cpu().numpy()\n","preds"],"metadata":{"id":"RRqC6rTpgF9s"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for i in range(2000, 2005):\n"," print(\"*********************************************************\")\n"," showExample(train=False, id=i)\n"," print(\"Predicted Answer:\\t\", answer_space[preds[i-2000]])\n"," print(\"*********************************************************\")"],"metadata":{"id":"VSsOz6AygHmo"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def countTrainableParameters(model):\n"," num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n"," print(\"No. 
of trainable parameters:\\t{0:,}\".format(num_params))"],"metadata":{"id":"FxQN0z6mgKEr"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["countTrainableParameters(model) # For BERT-ViT model"],"metadata":{"id":"msSSLjWRgMYm"},"execution_count":null,"outputs":[]}]}
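The evaluation metric used throughout the notebook above is a thresholded Wu-Palmer similarity (WUPS) between the predicted answer word and the ground-truth answer word. The following is a minimal sketch of that scoring rule, assuming nltk is installed and the WordNet corpus has been downloaded; the helper name wups and the sample word pairs are illustrative, not part of the original code. It drops the notebook's stem/semantic-field weights, which are fixed at 1.0 there, so for plain single-word answers it should produce the same score as wup_measure.

# Minimal WUPS sketch (assumptions: nltk is installed, WordNet corpus downloadable).
# Rule: take the best Wu-Palmer similarity over all noun synsets of the two
# answers; if it falls below the 0.925 threshold, downweight it by 0.1.
import nltk
nltk.download("wordnet", quiet=True)
from nltk.corpus import wordnet

def wups(a: str, b: str, threshold: float = 0.925) -> float:
    if a == b:
        return 1.0                      # identical answers get full credit
    syns_a = wordnet.synsets(a, pos=wordnet.NOUN)
    syns_b = wordnet.synsets(b, pos=wordnet.NOUN)
    if not syns_a or not syns_b:
        return 0.0                      # no noun sense found for one of the words
    # most optimistic interpretation: best similarity over all synset pairs
    best = max((x.wup_similarity(y) or 0.0) for x in syns_a for y in syns_b)
    return best if best >= threshold else 0.1 * best

print(wups("chair", "chair"))   # exact match -> 1.0
print(wups("chair", "table"))   # related nouns; downweighted if best similarity < 0.925
print(wups("chair", "banana"))  # unrelated nouns; low score, scaled by the 0.1 factor

Averaging this score over a batch of (label, prediction) pairs gives the "wups" value that compute_metrics reports during training and evaluation.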
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyPWNySDXDUOIRRHytPmodaP"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"hXVFx1TUe9Rc"},"outputs":[],"source":["import os\n","from copy import deepcopy\n","from dataclasses import dataclass\n","from typing import Dict, List, Optional, Tuple\n","from datasets import load_dataset, set_caching_enabled\n","import numpy as np\n","from PIL import Image\n","import torch\n","import torch.nn as nn\n","from transformers import (\n"," # Preprocessing / Common\n"," AutoTokenizer, AutoFeatureExtractor,\n"," # Text & Image Models (Now, image transformers like ViTModel, DeiTModel, BEiT can also be loaded using AutoModel)\n"," AutoModel,\n"," # Training / Evaluation\n"," TrainingArguments, Trainer,\n"," # Misc\n"," logging\n",")\n","\n","# import nltk\n","# nltk.download('wordnet')\n","from nltk.corpus import wordnet\n","\n","from sklearn.metrics import accuracy_score, f1_score"]},{"cell_type":"code","source":["# SET CACHE FOR HUGGINGFACE TRANSFORMERS + DATASETS\n","os.environ['HF_HOME'] = os.path.join(\".\", \"cache\")\n","# SET ONLY 1 GPU DEVICE\n","os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n","\n","set_caching_enabled(True)\n","logging.set_verbosity_error()"],"metadata":{"id":"0WpGBaJffO9Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n","print(device)\n","\n","#Additional Info when using cuda\n","if device.type == 'cuda':\n"," print(torch.cuda.get_device_name(0))\n","# print('Memory Usage:')\n","# print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')\n","# print('Cached: ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')"],"metadata":{"id":"A2es_I4_fQ_F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset = load_dataset(\n"," \"csv\",\n"," data_files={\n"," \"train\": os.path.join(\"..\", \"dataset\", \"data_train.csv\"),\n"," \"test\": os.path.join(\"..\", \"dataset\", \"data_eval.csv\")\n"," }\n",")\n","\n","with open(os.path.join(\"..\", \"dataset\", \"answer_space.txt\")) as f:\n"," answer_space = f.read().splitlines()\n","\n","dataset = dataset.map(\n"," lambda examples: {\n"," 'label': [\n"," answer_space.index(ans.replace(\" \", \"\").split(\",\")[0]) # Select the 1st answer if multiple answers are provided\n"," for ans in examples['answer']\n"," ]\n"," },\n"," batched=True\n",")\n","\n","dataset"],"metadata":{"id":"2q5Rc9wMfTYA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from IPython.display import display\n","\n","def showExample(train=True, id=None):\n"," if train:\n"," data = dataset[\"train\"]\n"," else:\n"," data = dataset[\"test\"]\n"," if id == None:\n"," id = np.random.randint(len(data))\n"," image = Image.open(os.path.join(\"..\", \"dataset\", \"images\", data[id][\"image_id\"] + \".png\"))\n"," display(image)\n","\n"," print(\"Question:\\t\", data[id][\"question\"])\n"," print(\"Answer:\\t\\t\", data[id][\"answer\"], \"(Label: {0})\".format(data[id][\"label\"]))"],"metadata":{"id":"RvMZjudffY-A"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["showExample()"],"metadata":{"id":"IbbXsvzlfch7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["@dataclass\n","class MultimodalCollator:\n"," tokenizer: AutoTokenizer\n"," preprocessor: AutoFeatureExtractor\n","\n"," def 
tokenize_text(self, texts: List[str]):\n"," encoded_text = self.tokenizer(\n"," text=texts,\n"," padding='longest',\n"," max_length=24,\n"," truncation=True,\n"," return_tensors='pt',\n"," return_token_type_ids=True,\n"," return_attention_mask=True,\n"," )\n"," return {\n"," \"input_ids\": encoded_text['input_ids'].squeeze(),\n"," \"token_type_ids\": encoded_text['token_type_ids'].squeeze(),\n"," \"attention_mask\": encoded_text['attention_mask'].squeeze(),\n"," }\n","\n"," def preprocess_images(self, images: List[str]):\n"," processed_images = self.preprocessor(\n"," images=[Image.open(os.path.join(\"..\", \"dataset\", \"images\", image_id + \".png\")).convert('RGB') for image_id in images],\n"," return_tensors=\"pt\",\n"," )\n"," return {\n"," \"pixel_values\": processed_images['pixel_values'].squeeze(),\n"," }\n","\n"," def __call__(self, raw_batch_dict):\n"," return {\n"," **self.tokenize_text(\n"," raw_batch_dict['question']\n"," if isinstance(raw_batch_dict, dict) else\n"," [i['question'] for i in raw_batch_dict]\n"," ),\n"," **self.preprocess_images(\n"," raw_batch_dict['image_id']\n"," if isinstance(raw_batch_dict, dict) else\n"," [i['image_id'] for i in raw_batch_dict]\n"," ),\n"," 'labels': torch.tensor(\n"," raw_batch_dict['label']\n"," if isinstance(raw_batch_dict, dict) else\n"," [i['label'] for i in raw_batch_dict],\n"," dtype=torch.int64\n"," ),\n"," }"],"metadata":{"id":"arT0QhVCfeXZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class MultimodalVQAModel(nn.Module):\n"," def __init__(\n"," self,\n"," num_labels: int = len(answer_space),\n"," intermediate_dim: int = 512,\n"," pretrained_text_name: str = 'bert-base-uncased',\n"," pretrained_image_name: str = 'google/vit-base-patch16-224-in21k'):\n","\n"," super(MultimodalVQAModel, self).__init__()\n"," self.num_labels = num_labels\n"," self.pretrained_text_name = pretrained_text_name\n"," self.pretrained_image_name = pretrained_image_name\n","\n"," self.text_encoder = AutoModel.from_pretrained(\n"," self.pretrained_text_name,\n"," )\n"," self.image_encoder = AutoModel.from_pretrained(\n"," self.pretrained_image_name,\n"," )\n"," self.fusion = nn.Sequential(\n"," nn.Linear(self.text_encoder.config.hidden_size + self.image_encoder.config.hidden_size, intermediate_dim),\n"," nn.ReLU(),\n"," nn.Dropout(0.5),\n"," )\n","\n"," self.classifier = nn.Linear(intermediate_dim, self.num_labels)\n","\n"," self.criterion = nn.CrossEntropyLoss()\n","\n"," def forward(\n"," self,\n"," input_ids: torch.LongTensor,\n"," pixel_values: torch.FloatTensor,\n"," attention_mask: Optional[torch.LongTensor] = None,\n"," token_type_ids: Optional[torch.LongTensor] = None,\n"," labels: Optional[torch.LongTensor] = None):\n","\n"," encoded_text = self.text_encoder(\n"," input_ids=input_ids,\n"," attention_mask=attention_mask,\n"," token_type_ids=token_type_ids,\n"," return_dict=True,\n"," )\n"," encoded_image = self.image_encoder(\n"," pixel_values=pixel_values,\n"," return_dict=True,\n"," )\n"," fused_output = self.fusion(\n"," torch.cat(\n"," [\n"," encoded_text['pooler_output'],\n"," encoded_image['pooler_output'],\n"," ],\n"," dim=1\n"," )\n"," )\n"," logits = self.classifier(fused_output)\n","\n"," out = {\n"," \"logits\": logits\n"," }\n"," if labels is not None:\n"," loss = self.criterion(logits, labels)\n"," out[\"loss\"] = loss\n","\n"," return out"],"metadata":{"id":"l_IgMBN2fi2y"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def createMultimodalVQACollatorAndModel(text='bert-base-uncased', 
image='google/vit-base-patch16-224-in21k'):\n"," tokenizer = AutoTokenizer.from_pretrained(text)\n"," preprocessor = AutoFeatureExtractor.from_pretrained(image)\n","\n"," multi_collator = MultimodalCollator(\n"," tokenizer=tokenizer,\n"," preprocessor=preprocessor,\n"," )\n","\n","\n"," multi_model = MultimodalVQAModel(pretrained_text_name=text, pretrained_image_name=image).to(device)\n"," return multi_collator, multi_model"],"metadata":{"id":"4nQFCbszfo2H"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def wup_measure(a,b,similarity_threshold=0.925):\n"," \"\"\"\n"," Returns Wu-Palmer similarity score.\n"," More specifically, it computes:\n"," max_{x \\in interp(a)} max_{y \\in interp(b)} wup(x,y)\n"," where interp is a 'interpretation field'\n"," \"\"\"\n"," def get_semantic_field(a):\n"," weight = 1.0\n"," semantic_field = wordnet.synsets(a,pos=wordnet.NOUN)\n"," return (semantic_field,weight)\n","\n","\n"," def get_stem_word(a):\n"," \"\"\"\n"," Sometimes answer has form word\\d+:wordid.\n"," If so we return word and downweight\n"," \"\"\"\n"," weight = 1.0\n"," return (a,weight)\n","\n","\n"," global_weight=1.0\n","\n"," (a,global_weight_a)=get_stem_word(a)\n"," (b,global_weight_b)=get_stem_word(b)\n"," global_weight = min(global_weight_a,global_weight_b)\n","\n"," if a==b:\n"," # they are the same\n"," return 1.0*global_weight\n","\n"," if a==[] or b==[]:\n"," return 0\n","\n","\n"," interp_a,weight_a = get_semantic_field(a)\n"," interp_b,weight_b = get_semantic_field(b)\n","\n"," if interp_a == [] or interp_b == []:\n"," return 0\n","\n"," # we take the most optimistic interpretation\n"," global_max=0.0\n"," for x in interp_a:\n"," for y in interp_b:\n"," local_score=x.wup_similarity(y)\n"," if local_score > global_max:\n"," global_max=local_score\n","\n"," # we need to use the semantic fields and therefore we downweight\n"," # unless the score is high which indicates both are synonyms\n"," if global_max < similarity_threshold:\n"," interp_weight = 0.1\n"," else:\n"," interp_weight = 1.0\n","\n"," final_score=global_max*weight_a*weight_b*interp_weight*global_weight\n"," return final_score"],"metadata":{"id":"s85RDZTZfqpk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def batch_wup_measure(labels, preds):\n"," wup_scores = [wup_measure(answer_space[label], answer_space[pred]) for label, pred in zip(labels, preds)]\n"," return np.mean(wup_scores)"],"metadata":{"id":"PLq55HN2fv6j"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["labels = np.random.randint(len(answer_space), size=5)\n","preds = np.random.randint(len(answer_space), size=5)\n","\n","def showAnswers(ids):\n"," print([answer_space[id] for id in ids])\n","\n","showAnswers(labels)\n","showAnswers(preds)\n","\n","print(\"Predictions vs Labels: \", batch_wup_measure(labels, preds))\n","print(\"Labels vs Labels: \", batch_wup_measure(labels, labels))"],"metadata":{"id":"qVhNj3mFfyP7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def compute_metrics(eval_tuple: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float]:\n"," logits, labels = eval_tuple\n"," preds = logits.argmax(axis=-1)\n"," return {\n"," \"wups\": batch_wup_measure(labels, preds),\n"," \"acc\": accuracy_score(labels, preds),\n"," \"f1\": f1_score(labels, preds, average='macro')\n"," }"],"metadata":{"id":"TZ-AmQhLf0CA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["args = TrainingArguments(\n"," output_dir=\"checkpoint\",\n"," seed=12345,\n"," 
evaluation_strategy=\"steps\",\n"," eval_steps=100,\n"," logging_strategy=\"steps\",\n"," logging_steps=100,\n"," save_strategy=\"steps\",\n"," save_steps=100,\n"," save_total_limit=3, # Save only the last 3 checkpoints at any given time while training\n"," metric_for_best_model='wups',\n"," per_device_train_batch_size=32,\n"," per_device_eval_batch_size=32,\n"," remove_unused_columns=False,\n"," num_train_epochs=5,\n"," fp16=True,\n"," # warmup_ratio=0.01,\n"," # learning_rate=5e-4,\n"," # weight_decay=1e-4,\n"," # gradient_accumulation_steps=2,\n"," dataloader_num_workers=8,\n"," load_best_model_at_end=True,\n",")"],"metadata":{"id":"VFzqARref1o1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def createAndTrainModel(dataset, args, text_model='bert-base-uncased', image_model='google/vit-base-patch16-224-in21k', multimodal_model='bert_vit'):\n"," collator, model = createMultimodalVQACollatorAndModel(text_model, image_model)\n","\n"," multi_args = deepcopy(args)\n"," multi_args.output_dir = os.path.join(\"..\", \"checkpoint\", multimodal_model)\n"," multi_trainer = Trainer(\n"," model,\n"," multi_args,\n"," train_dataset=dataset['train'],\n"," eval_dataset=dataset['test'],\n"," data_collator=collator,\n"," compute_metrics=compute_metrics\n"," )\n","\n"," train_multi_metrics = multi_trainer.train()\n"," eval_multi_metrics = multi_trainer.evaluate()\n","\n"," return collator, model, train_multi_metrics, eval_multi_metrics"],"metadata":{"id":"DmZE9CWaf3sr"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["collator, model, train_multi_metrics, eval_multi_metrics = createAndTrainModel(dataset, args)"],"metadata":{"id":"F0EgICdAf5oo"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["eval_multi_metrics"],"metadata":{"id":"Q1rJI1cBf7yl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["model = MultimodalVQAModel()\n","\n","# We use the checkpoint giving best results\n","model.load_state_dict(torch.load(os.path.join(\"..\", \"checkpoint\", \"bert_vit\", \"checkpoint-1500\", \"pytorch_model.bin\")))\n","model.to(device)"],"metadata":{"id":"SZvr0wC8f9ZL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["sample = collator(dataset[\"test\"][2000:2005])\n","\n","input_ids = sample[\"input_ids\"].to(device)\n","token_type_ids = sample[\"token_type_ids\"].to(device)\n","attention_mask = sample[\"attention_mask\"].to(device)\n","pixel_values = sample[\"pixel_values\"].to(device)\n","labels = sample[\"labels\"].to(device)"],"metadata":{"id":"eJT7NPq7f_xF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["model.eval()\n","output = model(input_ids, pixel_values, attention_mask, token_type_ids, labels)"],"metadata":{"id":"UYRJqDi8gDmH"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["preds = output[\"logits\"].argmax(axis=-1).cpu().numpy()\n","preds"],"metadata":{"id":"RRqC6rTpgF9s"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for i in range(2000, 2005):\n"," print(\"*********************************************************\")\n"," showExample(train=False, id=i)\n"," print(\"Predicted Answer:\\t\", answer_space[preds[i-2000]])\n"," print(\"*********************************************************\")"],"metadata":{"id":"VSsOz6AygHmo"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def countTrainableParameters(model):\n"," num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n"," print(\"No. 
of trainable parameters:\\t{0:,}\".format(num_params))"],"metadata":{"id":"FxQN0z6mgKEr"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["countTrainableParameters(model) # For BERT-ViT model"],"metadata":{"id":"msSSLjWRgMYm"},"execution_count":null,"outputs":[]}]}
\ No newline at end of file