{
"cells": [
{
"cell_type": "raw",
"metadata": {},
"source": [
"#\n",
"# Licensed to the Apache Software Foundation (ASF) under one or more\n",
"# contributor license agreements. See the NOTICE file distributed with\n",
"# this work for additional information regarding copyright ownership.\n",
"# The ASF licenses this file to You under the Apache License, Version 2.0\n",
"# (the \"License\"); you may not use this file except in compliance with\n",
"# the License. You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"#"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"nb_dir = os.path.split(os.getcwd())[0]\n",
"sys.path.append(os.getcwd() + \"/../\")"
]
},
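{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick optional sanity check: the cell above assumes the `bertft` package lives one directory above this notebook, so its parent directory should now be on `sys.path`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (assumption: the bertft package lives one level above this notebook)\n",
"parent = os.path.dirname(os.getcwd())\n",
"print(\"parent on sys.path:\", parent in sys.path)\n",
"print(\"bertft directory found:\", os.path.isdir(os.path.join(parent, \"bertft\")))"
]
},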
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)\n",
"logging.getLogger(\"bertft\").setLevel(logging.DEBUG)"
]
},
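{
"cell_type": "markdown",
"metadata": {},
"source": [
"The configuration above keeps the root logger at `INFO` while the `bertft` logger emits `DEBUG` messages; a minimal check of that behaviour:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# DEBUG from the \"bertft\" logger is shown; DEBUG from any other logger is suppressed (INFO level)\n",
"logging.getLogger(\"bertft\").debug(\"visible: bertft is at DEBUG level\")\n",
"logging.getLogger(\"some.other.module\").debug(\"suppressed: root level is INFO\")"
]
},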
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import bertft\n",
"from bertft import lget\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Important: auto-reload of bertft module\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def mk_graph(x1):\n",
" x1 = list(filter(lambda x: -2 < x < 0.99, x1))[:40]\n",
" kwargs = dict(alpha=0.3, bins=20)\n",
"\n",
" plt.hist(x1, **kwargs, color='g', label='FastText score')\n",
" plt.gca().set(title='Top 40 masks histogram of embeddings score', ylabel='Count')\n",
"\n",
" plt.legend()\n",
" plt.show()\n",
"\n",
"\n",
"def mk_graph2(x1):\n",
" kwargs = dict(alpha=1, bins=50)\n",
"\n",
" plt.hist(x1, **kwargs, color='r', label='Weighted score')\n",
" plt.gca().set(\n",
" title='Distribution of weighted score of top 200 unfiltered results (Target excluded)',\n",
" ylabel='Count'\n",
" )\n",
"\n",
" plt.legend()\n",
" plt.show()\n",
"\n",
"\n",
"def on_run(self, kunfiltered, unfiltered, filtered_top, target, tokenizer, top_tokens):\n",
" print(\"Unfiltered top:\")\n",
"\n",
" print(pd.DataFrame({\n",
" 'word': lget(kunfiltered, 0),\n",
" 'bert': self.dget(kunfiltered, 1),\n",
" 'normalized': self.dget(kunfiltered, 2),\n",
" 'ftext': self.dget(kunfiltered, 3),\n",
" 'ftext-sentence': self.dget(kunfiltered, 4),\n",
" 'score': lget(kunfiltered, 5),\n",
" }))\n",
"\n",
" print(\"Filtered top:\")\n",
"\n",
" print(filtered_top)\n",
"\n",
" mk_graph(lget(unfiltered, 2)[:100])\n",
" mk_graph2(lget(list(filter(lambda x: x[0] != target, unfiltered)), 4))\n",
"\n",
" if target is not None:\n",
" vec = tokenizer.encode(target, return_tensors=\"pt\")[0]\n",
" if len(vec) == 3:\n",
" tk = vec[1].item()\n",
" pos = None\n",
" score = None\n",
"\n",
" for e, (t, v) in enumerate(top_tokens):\n",
" if t == tk:\n",
" score = v\n",
" break\n",
" print(\"Original word position: %s; score: %s \" % (pos, score))\n",
" else:\n",
" if len(vec) > 3:\n",
" print(\"Original word is more then 1 token\")\n",
" print(tokenizer.tokenize(target))\n",
" else:\n",
" print(\"Original word wasn't found\")\n"
]
},
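{
"cell_type": "markdown",
"metadata": {},
"source": [
"`lget` (imported from `bertft`) is used above to pull a single column out of a list of tuples. The real implementation may differ; a minimal sketch of the assumed behaviour:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of what lget is assumed to do: extract column i from a list of tuples.\n",
"# Illustration only; the actual bertft.lget implementation may differ.\n",
"def lget_sketch(rows, i):\n",
"    return [row[i] for row in rows]\n",
"\n",
"lget_sketch([(\"cat\", 0.9), (\"dog\", 0.8)], 0)  # -> ['cat', 'dog']"
]
},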
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pipeline = bertft.Pipeline()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example of usage\n",
"res = pipeline.find_top(\n",
" # List of sentences with target word position\n",
" [\n",
" (\"what is the local weather forecast?\", 3, 4),\n",
" (\"what is chances of rain tomorrow?\", 4, 2),\n",
" (\"is driving a car faster then taking a bus?\", 3),\n",
" (\"who is the best football player of all time?\", 4)\n",
" ],\n",
" k = 20, # Filter best k results (by weighted score)\n",
" top_bert = 100, # Number of initial filter of bert output \n",
" min_ftext = 0.3, # Minimal required score of fast text \n",
" min_bert = 0.5, # Minimal required score of Bert \n",
" weights = [ # Weights of models scores to calculate total weighted score\n",
" 1, # bert\n",
" 1, # fast text\n",
" ],\n",
" min_score = 0 # Minimum required score\n",
")\n",
"print(res)"
]
},
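{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `weights` argument above controls how the BERT and FastText scores are blended into the total weighted score used for ranking. The exact formula is internal to `bertft`; a plausible sketch of such a weighted combination (an assumption, not the library's code):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch of a weighted score combination (not the actual bertft implementation)\n",
"def weighted_score(bert_score, ftext_score, weights=(1, 1)):\n",
"    w_bert, w_ftext = weights\n",
"    return (w_bert * bert_score + w_ftext * ftext_score) / (w_bert + w_ftext)\n",
"\n",
"weighted_score(0.8, 0.6)  # -> 0.7 with equal weights"
]
},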
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}