| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# Source\n", |
| "Notebook obtained from https://github.com/TeamHG-Memex/sklearn-crfsuite/blob/master/docs/CoNLL2002.ipynb" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "%matplotlib inline\n", |
| "import matplotlib.pyplot as plt\n", |
| "plt.style.use('ggplot')" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from itertools import chain\n", |
| "\n", |
| "import nltk\n", |
| "import sklearn\n", |
| "import scipy.stats\n", |
| "from sklearn.metrics import make_scorer\n", |
    "try:\n",
    "    # scikit-learn >= 0.18\n",
    "    from sklearn.model_selection import cross_val_score, RandomizedSearchCV\n",
    "except ImportError:\n",
    "    # older scikit-learn releases\n",
    "    from sklearn.cross_validation import cross_val_score\n",
    "    from sklearn.grid_search import RandomizedSearchCV\n",
| "\n", |
| "import sklearn_crfsuite\n", |
| "from sklearn_crfsuite import scorers\n", |
| "from sklearn_crfsuite import metrics" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "## Let's use CoNLL 2002 data to build a NER system\n", |
| "\n", |
    "The CoNLL 2002 corpus is available in NLTK. We use the Spanish data."
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import os\n", |
| "nltk.download(info_or_id='conll2002', download_dir=os.environ[\"MARVIN_DATA_PATH\"])\n", |
| "nltk.corpus.conll2002.fileids()" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "%%time\n", |
| "train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))\n", |
| "test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "train_sents[0]" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "## Features\n", |
| "\n", |
    "Next, define some features. In this example we use word identity, word suffix, word shape and word POS tag; some information from nearby words is also used.\n",
    "\n",
    "This makes a simple baseline, but you can certainly add and remove features to get (much?) better results - experiment with it.\n",
    "\n",
    "sklearn-crfsuite (and python-crfsuite) supports several feature formats; here we use feature dicts."
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def word2features(sent, i):\n", |
| " word = sent[i][0]\n", |
| " postag = sent[i][1]\n", |
| " \n", |
| " features = {\n", |
| " 'bias': 1.0,\n", |
| " 'word.lower()': word.lower(),\n", |
| " 'word[-3:]': word[-3:],\n", |
| " 'word[-2:]': word[-2:],\n", |
| " 'word.isupper()': word.isupper(),\n", |
| " 'word.istitle()': word.istitle(),\n", |
| " 'word.isdigit()': word.isdigit(),\n", |
| " 'postag': postag,\n", |
| " 'postag[:2]': postag[:2], \n", |
| " }\n", |
| " if i > 0:\n", |
| " word1 = sent[i-1][0]\n", |
| " postag1 = sent[i-1][1]\n", |
| " features.update({\n", |
| " '-1:word.lower()': word1.lower(),\n", |
| " '-1:word.istitle()': word1.istitle(),\n", |
| " '-1:word.isupper()': word1.isupper(),\n", |
| " '-1:postag': postag1,\n", |
| " '-1:postag[:2]': postag1[:2],\n", |
| " })\n", |
| " else:\n", |
| " features['BOS'] = True\n", |
| " \n", |
| " if i < len(sent)-1:\n", |
| " word1 = sent[i+1][0]\n", |
| " postag1 = sent[i+1][1]\n", |
| " features.update({\n", |
| " '+1:word.lower()': word1.lower(),\n", |
| " '+1:word.istitle()': word1.istitle(),\n", |
| " '+1:word.isupper()': word1.isupper(),\n", |
| " '+1:postag': postag1,\n", |
| " '+1:postag[:2]': postag1[:2],\n", |
| " })\n", |
| " else:\n", |
| " features['EOS'] = True\n", |
| " \n", |
| " return features\n", |
| "\n", |
| "\n", |
| "def sent2features(sent):\n", |
| " return [word2features(sent, i) for i in range(len(sent))]\n", |
| "\n", |
| "def sent2labels(sent):\n", |
| " return [label for token, postag, label in sent]\n", |
| "\n", |
| "def sent2tokens(sent):\n", |
| " return [token for token, postag, label in sent]" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
    "This is what ``word2features`` extracts:"
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "sent2features(train_sents[0])[0]" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Extract features from the data:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "%%time\n", |
| "X_train = [sent2features(s) for s in train_sents]\n", |
| "y_train = [sent2labels(s) for s in train_sents]\n", |
| "\n", |
| "X_test = [sent2features(s) for s in test_sents]\n", |
| "y_test = [sent2labels(s) for s in test_sents]" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "## Training\n", |
| "\n", |
    "To see all possible CRF parameters, check the docstring. Here we use the L-BFGS training algorithm (the default) with Elastic Net (L1 + L2) regularization."
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "%%time\n", |
| "crf = sklearn_crfsuite.CRF(\n", |
| " algorithm='lbfgs', \n", |
| " c1=0.1, \n", |
| " c2=0.1, \n", |
| " max_iterations=100, \n", |
| " all_possible_transitions=True\n", |
| ")\n", |
| "crf.fit(X_train, y_train)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "## Evaluation\n", |
| "\n", |
    "There are many more O tokens in the dataset than entity tokens, but we're more interested in the other labels. To account for this we'll use an averaged F1 score computed over all labels except O. The ``sklearn_crfsuite.metrics`` module provides some useful metrics for sequence classification tasks, including this one."
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "labels = list(crf.classes_)\n", |
| "labels.remove('O')\n", |
| "labels" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "y_pred = crf.predict(X_test)\n", |
| "metrics.flat_f1_score(y_test, y_pred, \n", |
| " average='weighted', labels=labels)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Inspect per-class results in more detail:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# group B and I results\n", |
| "sorted_labels = sorted(\n", |
| " labels, \n", |
| " key=lambda name: (name[1:], name[0])\n", |
| ")\n", |
| "print(metrics.flat_classification_report(\n", |
| " y_test, y_pred, labels=sorted_labels, digits=3\n", |
| "))" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "## Hyperparameter Optimization\n", |
| "\n", |
    "To improve quality, try selecting regularization parameters using randomized search and 3-fold cross-validation.\n",
    "\n",
    "It takes quite a lot of CPU time and RAM (we're fitting a model ``50 * 3 = 150`` times), so grab a cup of tea and be patient, or reduce ``n_iter`` in RandomizedSearchCV, or fit the model on only a subset of the training data."
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "%%time\n", |
| "# define fixed parameters and parameters to search\n", |
| "crf = sklearn_crfsuite.CRF(\n", |
| " algorithm='lbfgs', \n", |
| " max_iterations=100, \n", |
| " all_possible_transitions=True\n", |
| ")\n", |
| "params_space = {\n", |
| " 'c1': scipy.stats.expon(scale=0.5),\n", |
| " 'c2': scipy.stats.expon(scale=0.05),\n", |
| "}\n", |
| "\n", |
| "# use the same metric for evaluation\n", |
| "f1_scorer = make_scorer(metrics.flat_f1_score, \n", |
| " average='weighted', labels=labels)\n", |
| "\n", |
| "# search\n", |
| "rs = RandomizedSearchCV(crf, params_space, \n", |
| " cv=3, \n", |
| " verbose=1, \n", |
| " n_jobs=-1, \n", |
| " n_iter=50, \n", |
| " scoring=f1_scorer)\n", |
| "rs.fit(X_train, y_train)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Best result:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# crf = rs.best_estimator_\n", |
| "print('best params:', rs.best_params_)\n", |
| "print('best CV score:', rs.best_score_)\n", |
| "print('model size: {:0.2f}M'.format(rs.best_estimator_.size_ / 1000000))" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "### Check parameter space\n", |
| "\n", |
    "The chart below shows which ``c1`` and ``c2`` values RandomizedSearchCV checked. Red means better results, blue means worse."
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
    "# rs.grid_scores_ was removed in newer scikit-learn; fall back to cv_results_\n",
    "if hasattr(rs, 'grid_scores_'):\n",
    "    _x = [s.parameters['c1'] for s in rs.grid_scores_]\n",
    "    _y = [s.parameters['c2'] for s in rs.grid_scores_]\n",
    "    _c = [s.mean_validation_score for s in rs.grid_scores_]\n",
    "else:\n",
    "    _x = list(rs.cv_results_['param_c1'])\n",
    "    _y = list(rs.cv_results_['param_c2'])\n",
    "    _c = list(rs.cv_results_['mean_test_score'])\n",
| "\n", |
| "fig = plt.figure()\n", |
| "fig.set_size_inches(12, 12)\n", |
| "ax = plt.gca()\n", |
| "ax.set_yscale('log')\n", |
| "ax.set_xscale('log')\n", |
| "ax.set_xlabel('C1')\n", |
| "ax.set_ylabel('C2')\n", |
| "ax.set_title(\"Randomized Hyperparameter Search CV Results (min={:0.3}, max={:0.3})\".format(\n", |
| " min(_c), max(_c)\n", |
| "))\n", |
| "\n", |
| "ax.scatter(_x, _y, c=_c, s=60, alpha=0.9, edgecolors=[0,0,0])\n", |
| "\n", |
| "print(\"Dark blue => {:0.4}, dark red => {:0.4}\".format(min(_c), max(_c)))" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "## Check best estimator on our test data\n", |
| "\n", |
| "As you can see, quality is improved." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "crf = rs.best_estimator_\n", |
| "y_pred = crf.predict(X_test)\n", |
| "print(metrics.flat_classification_report(\n", |
| " y_test, y_pred, labels=sorted_labels, digits=3\n", |
| "))" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
    "## Let's check what the classifier learned"
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from collections import Counter\n", |
| "\n", |
| "def print_transitions(trans_features):\n", |
| " for (label_from, label_to), weight in trans_features:\n", |
| " print(\"%-6s -> %-7s %0.6f\" % (label_from, label_to, weight))\n", |
| "\n", |
| "print(\"Top likely transitions:\")\n", |
| "print_transitions(Counter(crf.transition_features_).most_common(20))\n", |
| "\n", |
| "print(\"\\nTop unlikely transitions:\")\n", |
| "print_transitions(Counter(crf.transition_features_).most_common()[-20:])" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "We can see that, for example, it is very likely that the beginning of an organization name (B-ORG) will be followed by a token inside organization name (I-ORG), but transitions to I-ORG from tokens with other labels are penalized.\n", |
| "\n", |
| "Check the state features:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def print_state_features(state_features):\n", |
| " for (attr, label), weight in state_features:\n", |
| " print(\"%0.6f %-8s %s\" % (weight, label, attr)) \n", |
| "\n", |
| "print(\"Top positive:\")\n", |
| "print_state_features(Counter(crf.state_features_).most_common(30))\n", |
| "\n", |
| "print(\"\\nTop negative:\")\n", |
| "print_state_features(Counter(crf.state_features_).most_common()[-30:])" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Some observations:\n", |
| "\n", |
| " * **9.385823 B-ORG word.lower():psoe-progresistas** - the model remembered names of some entities - maybe it is overfit, or maybe our features are not adequate, or maybe remembering is indeed helpful;\n", |
| " * **4.636151 I-LOC -1:word.lower():calle:** \"calle\" is a street in Spanish; model learns that if a previous word was \"calle\" then the token is likely a part of location;\n", |
| " * **-5.632036 O word.isupper()**, **-8.215073 O word.istitle()** : UPPERCASED or TitleCased words are likely entities of some kind;\n", |
| " * **-2.097561 O postag:NP** - proper nouns (NP is a proper noun in the Spanish tagset) are often entities.\n", |
| "\n", |
    "### What to do next\n",
    "\n",
    " * Load the 'testa' Spanish data.\n",
    " * Use it to develop better features and to find the best model parameters.\n",
    " * Apply the model to the 'testb' data again.\n",
| "\n", |
| "The model in this notebook is just a starting point; you certainly can do better!\n", |
| "\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 2", |
| "language": "python", |
| "name": "python2" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 2 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython2", |
| "version": "2.7.12" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 1 |
| } |