{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Source\n",
"Notebook obtained from https://github.com/TeamHG-Memex/sklearn-crfsuite/blob/master/docs/CoNLL2002.ipynb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"plt.style.use('ggplot')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from itertools import chain\n",
"\n",
"import nltk\n",
"import sklearn\n",
"import scipy.stats\n",
"from sklearn.metrics import make_scorer\n",
"from sklearn.cross_validation import cross_val_score\n",
"from sklearn.grid_search import RandomizedSearchCV\n",
"\n",
"import sklearn_crfsuite\n",
"from sklearn_crfsuite import scorers\n",
"from sklearn_crfsuite import metrics"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Let's use CoNLL 2002 data to build a NER system\n",
"\n",
"CoNLL2002 corpus is available in NLTK. We use Spanish data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"nltk.download(info_or_id='conll2002', download_dir=os.environ[\"MARVIN_DATA_PATH\"])\n",
"nltk.corpus.conll2002.fileids()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))\n",
"test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_sents[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Features\n",
"\n",
"Next, define some features. In this example we use word identity, word suffix, word shape and word POS tag; also, some information from nearby words is used. \n",
"\n",
"This makes a simple baseline, but you certainly can add and remove some features to get (much?) better results - experiment with it.\n",
"\n",
"sklearn-crfsuite (and python-crfsuite) supports several feature formats; here we use feature dicts."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def word2features(sent, i):\n",
" word = sent[i][0]\n",
" postag = sent[i][1]\n",
" \n",
" features = {\n",
" 'bias': 1.0,\n",
" 'word.lower()': word.lower(),\n",
" 'word[-3:]': word[-3:],\n",
" 'word[-2:]': word[-2:],\n",
" 'word.isupper()': word.isupper(),\n",
" 'word.istitle()': word.istitle(),\n",
" 'word.isdigit()': word.isdigit(),\n",
" 'postag': postag,\n",
" 'postag[:2]': postag[:2], \n",
" }\n",
" if i > 0:\n",
" word1 = sent[i-1][0]\n",
" postag1 = sent[i-1][1]\n",
" features.update({\n",
" '-1:word.lower()': word1.lower(),\n",
" '-1:word.istitle()': word1.istitle(),\n",
" '-1:word.isupper()': word1.isupper(),\n",
" '-1:postag': postag1,\n",
" '-1:postag[:2]': postag1[:2],\n",
" })\n",
" else:\n",
" features['BOS'] = True\n",
" \n",
" if i < len(sent)-1:\n",
" word1 = sent[i+1][0]\n",
" postag1 = sent[i+1][1]\n",
" features.update({\n",
" '+1:word.lower()': word1.lower(),\n",
" '+1:word.istitle()': word1.istitle(),\n",
" '+1:word.isupper()': word1.isupper(),\n",
" '+1:postag': postag1,\n",
" '+1:postag[:2]': postag1[:2],\n",
" })\n",
" else:\n",
" features['EOS'] = True\n",
" \n",
" return features\n",
"\n",
"\n",
"def sent2features(sent):\n",
" return [word2features(sent, i) for i in range(len(sent))]\n",
"\n",
"def sent2labels(sent):\n",
" return [label for token, postag, label in sent]\n",
"\n",
"def sent2tokens(sent):\n",
" return [token for token, postag, label in sent]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is what word2features extracts:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sent2features(train_sents[0])[0]"
]
},
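{
"cell_type": "markdown",
"metadata": {},
"source": [
"The feature set is easy to extend. As an illustration only (this ``word_shape`` helper is a sketch, not part of the original tutorial), a coarse word-shape feature could be added to ``word2features``:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"def word_shape(word):\n",
" # hypothetical extra feature: collapse character runs into a coarse shape,\n",
" # e.g. 'McGregor' -> 'XxXx' and '1999' -> 'd'\n",
" shape = re.sub('[A-Z]+', 'X', word)\n",
" shape = re.sub('[a-z]+', 'x', shape)\n",
" return re.sub('[0-9]+', 'd', shape)\n",
"\n",
"word_shape('McGregor'), word_shape('1999')"
]
},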
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Extract features from the data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"X_train = [sent2features(s) for s in train_sents]\n",
"y_train = [sent2labels(s) for s in train_sents]\n",
"\n",
"X_test = [sent2features(s) for s in test_sents]\n",
"y_test = [sent2labels(s) for s in test_sents]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training\n",
"\n",
"To see all possible CRF parameters check its docstring. Here we are useing L-BFGS training algorithm (it is default) with Elastic Net (L1 + L2) regularization."
]
},
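{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# pull up the full list of CRF parameters right in the notebook\n",
"help(sklearn_crfsuite.CRF)"
]
},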
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"crf = sklearn_crfsuite.CRF(\n",
" algorithm='lbfgs', \n",
" c1=0.1, \n",
" c2=0.1, \n",
" max_iterations=100, \n",
" all_possible_transitions=True\n",
")\n",
"crf.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation\n",
"\n",
"There is much more O entities in data set, but we're more interested in other entities. To account for this we'll use averaged F1 score computed for all labels except for O. ``sklearn-crfsuite.metrics`` package provides some useful metrics for sequence classification task, including this one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"labels = list(crf.classes_)\n",
"labels.remove('O')\n",
"labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_pred = crf.predict(X_test)\n",
"metrics.flat_f1_score(y_test, y_pred, \n",
" average='weighted', labels=labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inspect per-class results in more detail:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# group B and I results\n",
"sorted_labels = sorted(\n",
" labels, \n",
" key=lambda name: (name[1:], name[0])\n",
")\n",
"print(metrics.flat_classification_report(\n",
" y_test, y_pred, labels=sorted_labels, digits=3\n",
"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hyperparameter Optimization\n",
"\n",
"To improve quality try to select regularization parameters using randomized search and 3-fold cross-validation.\n",
"\n",
"I takes quite a lot of CPU time and RAM (we're fitting a model ``50 * 3 = 150`` times), so grab a tea and be patient, or reduce n_iter in RandomizedSearchCV, or fit model only on a subset of training data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"# define fixed parameters and parameters to search\n",
"crf = sklearn_crfsuite.CRF(\n",
" algorithm='lbfgs', \n",
" max_iterations=100, \n",
" all_possible_transitions=True\n",
")\n",
"params_space = {\n",
" 'c1': scipy.stats.expon(scale=0.5),\n",
" 'c2': scipy.stats.expon(scale=0.05),\n",
"}\n",
"\n",
"# use the same metric for evaluation\n",
"f1_scorer = make_scorer(metrics.flat_f1_score, \n",
" average='weighted', labels=labels)\n",
"\n",
"# search\n",
"rs = RandomizedSearchCV(crf, params_space, \n",
" cv=3, \n",
" verbose=1, \n",
" n_jobs=-1, \n",
" n_iter=50, \n",
" scoring=f1_scorer)\n",
"rs.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Best result:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# crf = rs.best_estimator_\n",
"print('best params:', rs.best_params_)\n",
"print('best CV score:', rs.best_score_)\n",
"print('model size: {:0.2f}M'.format(rs.best_estimator_.size_ / 1000000))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check parameter space\n",
"\n",
"A chart which shows which ``c1`` and ``c2`` values have RandomizedSearchCV checked. Red color means better results, blue means worse."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"_x = [s.parameters['c1'] for s in rs.grid_scores_]\n",
"_y = [s.parameters['c2'] for s in rs.grid_scores_]\n",
"_c = [s.mean_validation_score for s in rs.grid_scores_]\n",
"\n",
"fig = plt.figure()\n",
"fig.set_size_inches(12, 12)\n",
"ax = plt.gca()\n",
"ax.set_yscale('log')\n",
"ax.set_xscale('log')\n",
"ax.set_xlabel('C1')\n",
"ax.set_ylabel('C2')\n",
"ax.set_title(\"Randomized Hyperparameter Search CV Results (min={:0.3}, max={:0.3})\".format(\n",
" min(_c), max(_c)\n",
"))\n",
"\n",
"ax.scatter(_x, _y, c=_c, s=60, alpha=0.9, edgecolors=[0,0,0])\n",
"\n",
"print(\"Dark blue => {:0.4}, dark red => {:0.4}\".format(min(_c), max(_c)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Check best estimator on our test data\n",
"\n",
"As you can see, quality is improved."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"crf = rs.best_estimator_\n",
"y_pred = crf.predict(X_test)\n",
"print(metrics.flat_classification_report(\n",
" y_test, y_pred, labels=sorted_labels, digits=3\n",
"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Let's check what classifier learned"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"def print_transitions(trans_features):\n",
" for (label_from, label_to), weight in trans_features:\n",
" print(\"%-6s -> %-7s %0.6f\" % (label_from, label_to, weight))\n",
"\n",
"print(\"Top likely transitions:\")\n",
"print_transitions(Counter(crf.transition_features_).most_common(20))\n",
"\n",
"print(\"\\nTop unlikely transitions:\")\n",
"print_transitions(Counter(crf.transition_features_).most_common()[-20:])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that, for example, it is very likely that the beginning of an organization name (B-ORG) will be followed by a token inside organization name (I-ORG), but transitions to I-ORG from tokens with other labels are penalized.\n",
"\n",
"Check the state features:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def print_state_features(state_features):\n",
" for (attr, label), weight in state_features:\n",
" print(\"%0.6f %-8s %s\" % (weight, label, attr)) \n",
"\n",
"print(\"Top positive:\")\n",
"print_state_features(Counter(crf.state_features_).most_common(30))\n",
"\n",
"print(\"\\nTop negative:\")\n",
"print_state_features(Counter(crf.state_features_).most_common()[-30:])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"Some observations:\n",
"\n",
" * **9.385823 B-ORG word.lower():psoe-progresistas** - the model remembered names of some entities - maybe it is overfit, or maybe our features are not adequate, or maybe remembering is indeed helpful;\n",
" * **4.636151 I-LOC -1:word.lower():calle:** \"calle\" is a street in Spanish; model learns that if a previous word was \"calle\" then the token is likely a part of location;\n",
" * **-5.632036 O word.isupper()**, **-8.215073 O word.istitle()** : UPPERCASED or TitleCased words are likely entities of some kind;\n",
" * **-2.097561 O postag:NP** - proper nouns (NP is a proper noun in the Spanish tagset) are often entities.\n",
"\n",
"What to do next\n",
"\n",
" * Load 'testa' Spanish data.\n",
" * Use it to develop better features and to find best model parameters.\n",
" * Apply the model to 'testb' data again.\n",
"\n",
"The model in this notebook is just a starting point; you certainly can do better!\n",
"\n"
]
},
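{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of the first step (using the same corpus reader and feature functions as before; 'esp.testa' is the CoNLL 2002 Spanish development split), load 'testa' and score the tuned model on it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load the development set and evaluate the tuned model on it\n",
"dev_sents = list(nltk.corpus.conll2002.iob_sents('esp.testa'))\n",
"X_dev = [sent2features(s) for s in dev_sents]\n",
"y_dev = [sent2labels(s) for s in dev_sents]\n",
"metrics.flat_f1_score(y_dev, crf.predict(X_dev),\n",
" average='weighted', labels=labels)"
]
},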
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 1
}