blob: d9e80d882390ca30a36745ce8a358c500ee476bf [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
def vis_detection(im_orig, detections, class_names, thresh=0.7):
    """Draw detection boxes and labels over an image, then show the figure.

    Each row of ``detections`` is unpacked as ``[cls, conf, x1, y1, x2, y2]``.
    A row is drawn only when ``int(cls) > 0`` and ``conf > thresh``. Class id 0
    is never drawn (presumably the background class -- confirm with callers).

    Args:
        im_orig: image handed straight to ``plt.imshow``.
        detections: iterable of 6-element rows ``[cls, conf, x1, y1, x2, y2]``.
            The box corners are used as (x1, y1) top-left and (x2, y2)
            bottom-right when computing width and height.
        class_names: label strings indexed by integer class id. One random RGB
            color is generated per entry.
        thresh: confidence cutoff. Only rows with ``conf`` strictly above it
            are drawn.
    """
    import matplotlib.pyplot as plt
    import random

    plt.imshow(im_orig)

    # One random (r, g, b) float triple per class. random.random() is called
    # in r, g, b order for each class in turn.
    palette = [tuple(random.random() for _ in range(3)) for _ in class_names]

    for row in detections:
        label_id, score, left, top, right, bottom = row
        label_id = int(label_id)
        # Written as not (...) so a NaN score is skipped, same as the original
        # positive test.
        if not (label_id > 0 and score > thresh):
            continue

        shade = palette[label_id]
        box = plt.Rectangle((left, top), right - left, bottom - top,
                            fill=False, edgecolor=shade, linewidth=3.5)
        plt.gca().add_patch(box)

        caption = '{:s} {:.3f}'.format(class_names[label_id], score)
        # Put the label just above the box's top-left corner.
        plt.gca().text(left, top - 2, caption,
                       bbox=dict(facecolor=shade, alpha=0.5),
                       fontsize=12, color='white')

    plt.show()