#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# -*- coding: utf-8 -*-
import mxnet as mx
import numpy as np
import pickle
def load_obj(name):
    """
    Unpickle and return the object stored in the file ``name + '.pkl'``.

    ``name`` is the path without the ``.pkl`` extension; the extension is
    appended here. Whatever ``open`` or ``pickle.load`` raises is propagated.
    """
    pickle_path = name + '.pkl'
    # NOTE: pickle can execute arbitrary code while loading, so only point
    # this at files that come from a trusted source.
    with open(pickle_path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded
# Unpickled at import time from "../data/tag_to_index.pkl". The path is
# relative, so it resolves against the current working directory.
# Presumably maps tag names to integer class indices -- confirm against the
# script that writes the pickle.
tag_dict = load_obj("../data/tag_to_index")
# Index stored under the "O" tag; used in this module as the "not an entity"
# class index.
not_entity_index = tag_dict["O"]
def classifer_metrics(label, pred, outside_index=None):
    """
    Compute precision, recall and F1 on the entity class.

    A position counts as an entity when its tag index differs from the
    "not an entity" index. A predicted entity is correct only when the
    predicted tag index equals the label at that position.

    Parameters
    ----------
    label : numpy.ndarray
        Gold tag indices; cast to ``int`` before comparison.
    pred : numpy.ndarray
        Per-class scores; the predicted tag is the argmax over axis 1.
        NOTE(review): the argmax result is assumed to have the same shape
        as ``label`` -- confirm against the caller.
    outside_index : int, optional
        Tag index meaning "not an entity". When omitted, the module-level
        ``not_entity_index`` is used, as before.

    Returns
    -------
    tuple
        ``(precision, recall, f1)``. Precision is ``nan`` when nothing was
        predicted as an entity; recall is ``nan`` when ``label`` contains no
        entity; F1 is ``nan`` when either is ``nan`` and ``0`` when both
        are zero.
    """
    if outside_index is None:
        outside_index = not_entity_index

    prediction = np.argmax(pred, axis=1)
    label = label.astype(int)

    pred_is_entity = prediction != outside_index
    label_is_entity = label != outside_index

    # how many entities are there, and how many did we predict?
    num_entities = np.sum(label_is_entity)
    entity_preds = np.sum(pred_is_entity)

    # how many times did we correctly predict an entity?
    correct_entities = np.sum(pred_is_entity & (prediction == label))

    # float() keeps the ratios a true division even where `/` on integers
    # would truncate (e.g. Python 2 without `from __future__ import division`).

    # precision: when we predict entity, how often are we right?
    if entity_preds == 0:
        precision = np.nan
    else:
        precision = float(correct_entities) / float(entity_preds)

    # recall: of the things that were an entity, how many did we catch?
    # The zero check comes BEFORE the division so that a batch without any
    # gold entity does not trigger a NumPy RuntimeWarning for 0/0.
    if num_entities == 0:
        recall = np.nan
    else:
        recall = float(correct_entities) / float(num_entities)

    if np.isnan(precision) or np.isnan(recall):
        # An undefined component makes the harmonic mean undefined as well.
        f1 = np.nan
    elif precision + recall == 0:
        # Avoid 0/0 when both components are zero.
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1
def entity_precision(label, pred):
    """Return only the precision component of ``classifer_metrics``."""
    precision, _recall, _f1 = classifer_metrics(label, pred)
    return precision
def entity_recall(label, pred):
    """Return only the recall component of ``classifer_metrics``."""
    _precision, recall, _f1 = classifer_metrics(label, pred)
    return recall
def entity_f1(label, pred):
    """Return only the F1 component of ``classifer_metrics``."""
    _precision, _recall, f1 = classifer_metrics(label, pred)
    return f1
def composite_classifier_metrics():
    """
    Build the composite evaluation metric used for the tagger.

    The returned ``mx.metric.CompositeEvalMetric`` holds, in this order:
    plain accuracy, entity precision, entity recall and entity F1 score.
    """
    entity_metrics = [
        mx.metric.CustomMetric(feval=scorer, name=title)
        for scorer, title in (
            (entity_precision, 'entity precision'),
            (entity_recall, 'entity recall'),
            (entity_f1, 'entity f1 score'),
        )
    ]
    accuracy = mx.metric.Accuracy()
    return mx.metric.CompositeEvalMetric([accuracy] + entity_metrics)