blob: 9dc13e4b69575a49c420b0b6d47403fc000e3cc1 [file] [log] [blame]
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import heapq
import math
import numpy as np
class Caption(object):
    """A complete or partial caption, ordered by beam-search score.

    Instances are pushed onto heaps (see TopN) and sorted, so the rich
    comparison methods below define ordering by ``score`` alone.
    """

    def __init__(self, sentence, state, logprob, score):
        """Initializes the Caption.

        Args:
            sentence: list of word ids in the caption so far.
            state: model state after generating the previous word.
            logprob: accumulated log probability of the caption.
            score: ranking score (logprob, optionally length-normalized).
        """
        self.sentence = sentence
        self.state = state
        self.logprob = logprob
        self.score = score

    def __cmp__(self, other):
        """Compares Captions by score (Python 2 only; ignored by Python 3)."""
        assert isinstance(other, Caption)
        if self.score == other.score:
            return 0
        elif self.score < other.score:
            return -1
        else:
            return 1

    # For Python 3 compatibility (__cmp__ is deprecated).
    def __lt__(self, other):
        """Orders Captions by score; defers to the other operand otherwise."""
        # Returning NotImplemented (rather than asserting) lets Python fall
        # back to the other operand's comparison, instead of raising
        # AssertionError -- and keeps behavior sane under `python -O`,
        # which strips asserts.
        if not isinstance(other, Caption):
            return NotImplemented
        return self.score < other.score

    # Also for Python 3 compatibility.
    def __eq__(self, other):
        """Two Captions are equal iff their scores are equal."""
        if not isinstance(other, Caption):
            return NotImplemented
        return self.score == other.score
class TopN(object):
    """Keeps the N largest elements of an incrementally provided set."""

    def __init__(self, n):
        # Capacity of the container; the min-heap below never grows past it.
        self._n = n
        self._data = []

    def size(self):
        """Returns the number of elements currently held."""
        assert self._data is not None
        return len(self._data)

    def push(self, x):
        """Adds x, evicting the current smallest element when at capacity."""
        assert self._data is not None
        if len(self._data) >= self._n:
            # At capacity: push x and pop the heap minimum in one operation.
            heapq.heappushpop(self._data, x)
        else:
            heapq.heappush(self._data, x)

    def extract(self, sort=False):
        """Removes and returns all elements.

        This is destructive: the only method that may be called afterwards
        is reset(). If sort is True, elements come back largest-first.
        """
        assert self._data is not None
        result, self._data = self._data, None
        if sort:
            result.sort(reverse=True)
        return result

    def reset(self):
        """Returns the TopN to an empty, usable state."""
        self._data = []
class CaptionGenerator(object):
    """
    Class to generate captions from an image-to-text model via beam search.
    """
    def __init__(self,
                 model,
                 vocab,
                 beam_size,
                 max_caption_length,
                 length_normalization_factor=0.0):
        """Initializes the generator.

        Args:
            model: object exposing feed_image(sess, image) and
                inference_step(sess, input_feed, state_feed); the exact
                contract is defined elsewhere -- see beam_search for the
                shapes assumed here.
            vocab: vocabulary object; must provide start_id and end_id.
            beam_size: number of beams kept per step (and top words expanded
                per beam).
            max_caption_length: maximum caption length, in words, including
                the start word.
            length_normalization_factor: if > 0, completed captions are
                scored by logprob / len(sentence)**factor instead of raw
                logprob, favoring longer captions as the factor grows.
        """
        self.vocab = vocab
        self.model = model
        self.beam_size = beam_size
        self.max_caption_length = max_caption_length
        self.length_normalization_factor = length_normalization_factor
    def beam_search(self, sess, encoded_image):
        """Runs beam search caption generation on a single image.

        Args:
            sess: session object passed through to the model's methods.
            encoded_image: single encoded image accepted by model.feed_image.

        Returns:
            List of Caption objects sorted by descending score. If no beam
            reached vocab.end_id, the surviving partial captions are
            returned instead (never a mixture of both).
        """
        # Feed in the image to get the initial state.
        # NOTE(review): feed_image appears to return a batch of states and
        # only the first is used -- confirm against the model's contract.
        initial_state = self.model.feed_image(sess, encoded_image)
        initial_beam = Caption(
            sentence=[self.vocab.start_id],
            state=initial_state[0],
            logprob=0.0,
            score=0.0)
        partial_captions = TopN(self.beam_size)
        partial_captions.push(initial_beam)
        complete_captions = TopN(self.beam_size)
        # Run beam search. The start word already occupies one slot, hence
        # max_caption_length - 1 expansion steps.
        for _ in range(self.max_caption_length - 1):
            # extract() empties the TopN destructively; reset() makes it
            # usable again so this step's expansions can be pushed back in.
            partial_captions_list = partial_captions.extract()
            partial_captions.reset()
            # Batch one inference step over all live beams: last word id and
            # model state of each beam, stacked along axis 0.
            input_feed = np.array([c.sentence[-1] for c in partial_captions_list])
            state_feed = np.array([c.state for c in partial_captions_list])
            # Assumes inference_step returns per-beam word distributions and
            # per-beam new states, indexable by beam position i -- TODO confirm.
            softmax, new_states = self.model.inference_step(sess, input_feed, state_feed)
            for i, partial_caption in enumerate(partial_captions_list):
                word_probabilities = softmax[i]
                state = new_states[i]
                # For this partial caption, get the beam_size most probable
                # next words (sorted by descending probability).
                words_and_probs = list(enumerate(word_probabilities))
                words_and_probs.sort(key=lambda x: -x[1])
                words_and_probs = words_and_probs[0:self.beam_size]
                # Each next word gives a new partial caption.
                for w, p in words_and_probs:
                    if p < 1e-12:
                        continue  # Avoid log(0).
                    sentence = partial_caption.sentence + [w]
                    logprob = partial_caption.logprob + math.log(p)
                    score = logprob
                    if w == self.vocab.end_id:
                        # Beam is complete: optionally length-normalize its
                        # score before it competes with other completed beams.
                        if self.length_normalization_factor > 0:
                            score /= len(sentence) ** self.length_normalization_factor
                        beam = Caption(sentence, state, logprob, score)
                        complete_captions.push(beam)
                    else:
                        beam = Caption(sentence, state, logprob, score)
                        partial_captions.push(beam)
            if partial_captions.size() == 0:
                # We have run out of partial candidates; happens when beam_size = 1.
                break
        # If we have no complete captions then fall back to the partial captions,
        # but never output a mixture of complete and partial captions because a
        # partial caption could have a higher score than all the complete captions.
        if not complete_captions.size():
            complete_captions = partial_captions
        return complete_captions.extract(sort=True)