| { |
| "cells": [ |
| { |
| "cell_type": "raw", |
| "metadata": {}, |
| "source": [ |
| "#\n", |
| "# Licensed to the Apache Software Foundation (ASF) under one or more\n", |
| "# contributor license agreements. See the NOTICE file distributed with\n", |
| "# this work for additional information regarding copyright ownership.\n", |
| "# The ASF licenses this file to You under the Apache License, Version 2.0\n", |
| "# (the \"License\"); you may not use this file except in compliance with\n", |
| "# the License. You may obtain a copy of the License at\n", |
| "#\n", |
| "# http://www.apache.org/licenses/LICENSE-2.0\n", |
| "#\n", |
| "# Unless required by applicable law or agreed to in writing, software\n", |
| "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", |
| "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", |
| "# See the License for the specific language governing permissions and\n", |
| "# limitations under the License.\n", |
| "#" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import sys\n", |
| "import os\n", |
| "# Parent of the current working directory.\n", |
| "nb_dir = os.path.dirname(os.getcwd())\n", |
| "# Let this notebook import modules that live one level above the working directory.\n", |
| "sys.path.append(os.getcwd() + \"/../\")" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import logging\n", |
| "# Timestamped messages at INFO by default; the \"bertft\" logger is switched to DEBUG.\n", |
| "logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')\n", |
| "logging.getLogger(\"bertft\").setLevel(logging.DEBUG)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import bertft\n", |
| "from bertft import lget\n", |
| "import matplotlib.pyplot as plt\n", |
| "import pandas as pd" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Important: enable auto-reload so that edits to the bertft module are picked up without restarting the kernel\n", |
| "%load_ext autoreload\n", |
| "%autoreload 2" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def mk_graph(x1):\n", |
| "    \"\"\"Plot a 20-bin histogram of at most 40 values of ``x1`` lying strictly between -2 and 0.99.\"\"\"\n", |
| "    in_range = [score for score in x1 if -2 < score < 0.99]\n", |
| "    shown = in_range[:40]\n", |
| "\n", |
| "    plt.hist(shown, bins=20, alpha=0.3, color='g', label='FastText score')\n", |
| "    axes = plt.gca()\n", |
| "    axes.set(title='Top 40 masks histogram of embeddings score', ylabel='Count')\n", |
| "\n", |
| "    plt.legend()\n", |
| "    plt.show()\n", |
| "\n", |
| "\n", |
| "def mk_graph2(x1):\n", |
| "    \"\"\"Plot a 50-bin histogram of the values in ``x1``, labelled as weighted scores.\"\"\"\n", |
| "    hist_title = 'Distribution of weighted score of top 200 unfiltered results (Target excluded)'\n", |
| "\n", |
| "    plt.hist(x1, bins=50, alpha=1, color='r', label='Weighted score')\n", |
| "    axes = plt.gca()\n", |
| "    axes.set(title=hist_title, ylabel='Count')\n", |
| "\n", |
| "    plt.legend()\n", |
| "    plt.show()\n", |
| "\n", |
| "\n", |
| "def on_run(self, kunfiltered, unfiltered, filtered_top, target, tokenizer, top_tokens):\n", |
| "    \"\"\"Print and plot diagnostics for one run.\n", |
| "\n", |
| "    Prints the columns of ``kunfiltered`` as a DataFrame and ``filtered_top`` as is,\n", |
| "    draws two histograms from ``unfiltered`` and, when ``target`` is not None,\n", |
| "    reports the index and score of the target token inside ``top_tokens``.\n", |
| "\n", |
| "    ``self`` must provide ``dget``; ``tokenizer`` must provide ``encode`` and\n", |
| "    ``tokenize``; ``top_tokens`` is iterated as pairs whose first item is compared\n", |
| "    with the encoded target token and whose second item is reported as the score.\n", |
| "    \"\"\"\n", |
| "    print(\"Unfiltered top:\")\n", |
| "\n", |
| "    print(pd.DataFrame({\n", |
| "        'word': lget(kunfiltered, 0),\n", |
| "        'bert': self.dget(kunfiltered, 1),\n", |
| "        'normalized': self.dget(kunfiltered, 2),\n", |
| "        'ftext': self.dget(kunfiltered, 3),\n", |
| "        'ftext-sentence': self.dget(kunfiltered, 4),\n", |
| "        'score': lget(kunfiltered, 5),\n", |
| "    }))\n", |
| "\n", |
| "    print(\"Filtered top:\")\n", |
| "\n", |
| "    print(filtered_top)\n", |
| "\n", |
| "    mk_graph(lget(unfiltered, 2)[:100])\n", |
| "    # The target word itself is left out of the weighted-score histogram.\n", |
| "    mk_graph2(lget([row for row in unfiltered if row[0] != target], 4))\n", |
| "\n", |
| "    if target is not None:\n", |
| "        vec = tokenizer.encode(target, return_tensors=\"pt\")[0]\n", |
| "        if len(vec) == 3:\n", |
| "            # NOTE(review): presumably one word token between two special tokens - confirm against the tokenizer.\n", |
| "            tk = vec[1].item()\n", |
| "            pos = None\n", |
| "            score = None\n", |
| "\n", |
| "            for e, (t, v) in enumerate(top_tokens):\n", |
| "                if t == tk:\n", |
| "                    # Keep the match index too, so the position printed below is not always None.\n", |
| "                    pos = e\n", |
| "                    score = v\n", |
| "                    break\n", |
| "            print(\"Original word position: %s; score: %s \" % (pos, score))\n", |
| "        elif len(vec) > 3:\n", |
| "            print(\"Original word is more then 1 token\")\n", |
| "            print(tokenizer.tokenize(target))\n", |
| "        else:\n", |
| "            print(\"Original word wasn't found\")\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Create the bertft pipeline that the usage example below calls find_top on.\n", |
| "pipeline = bertft.Pipeline()" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Example of usage\n", |
| "res = pipeline.find_top(\n", |
| "    # Sentences, each followed by the position data of its target word\n", |
| "    [\n", |
| "        (\"what is the local weather forecast?\", 3, 4),\n", |
| "        (\"what is chances of rain tomorrow?\", 4, 2),\n", |
| "        (\"is driving a car faster then taking a bus?\", 3),\n", |
| "        (\"who is the best football player of all time?\", 4),\n", |
| "    ],\n", |
| "    k=20,  # Keep the best k results (by weighted score)\n", |
| "    top_bert=100,  # Size of the initial filter applied to the bert output\n", |
| "    min_ftext=0.3,  # Minimal required fast text score\n", |
| "    min_bert=0.5,  # Minimal required Bert score\n", |
| "    weights=[1, 1],  # Weights of the model scores in the total weighted score: [bert, fast text]\n", |
| "    min_score=0,  # Minimum required score\n", |
| ")\n", |
| "print(res)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 3", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.8.3" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 2 |
| } |