# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
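"""Unit tests for mxnet.contrib.text: token counting, Vocabulary, and token embeddings."""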
from __future__ import absolute_import
from __future__ import print_function
from collections import Counter
from common import assertRaises
from mxnet import ndarray as nd
from mxnet.test_utils import *
from mxnet.contrib import text
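
# Builds a small three-sequence corpus ('Life is great !', 'life is good .',
# "life isn't bad .") joined with the given token and sequence delimiters.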
def _get_test_str_of_tokens(token_delim, seq_delim):
    seq1 = token_delim + token_delim.join(['Life', 'is', 'great', '!']) + token_delim + seq_delim
    seq2 = token_delim + token_delim.join(['life', 'is', 'good', '.']) + token_delim + seq_delim
    seq3 = token_delim + token_delim.join(['life', "isn't", 'bad', '.']) + token_delim + seq_delim
    seqs = seq1 + seq2 + seq3
    return seqs

def _test_count_tokens_from_str_with_delims(token_delim, seq_delim):
    source_str = _get_test_str_of_tokens(token_delim, seq_delim)

    cnt1 = text.utils.count_tokens_from_str(
        source_str, token_delim, seq_delim, to_lower=False)
    assert cnt1 == Counter(
        {'is': 2, 'life': 2, '.': 2, 'Life': 1, 'great': 1, '!': 1, 'good': 1, "isn't": 1,
         'bad': 1})

    cnt2 = text.utils.count_tokens_from_str(
        source_str, token_delim, seq_delim, to_lower=True)
    assert cnt2 == Counter(
        {'life': 3, 'is': 2, '.': 2, 'great': 1, '!': 1, 'good': 1, "isn't": 1, 'bad': 1})

    counter_to_update = Counter({'life': 2})

    cnt3 = text.utils.count_tokens_from_str(
        source_str, token_delim, seq_delim, to_lower=False,
        counter_to_update=counter_to_update.copy())
    assert cnt3 == Counter(
        {'is': 2, 'life': 4, '.': 2, 'Life': 1, 'great': 1, '!': 1, 'good': 1, "isn't": 1,
         'bad': 1})

    cnt4 = text.utils.count_tokens_from_str(
        source_str, token_delim, seq_delim, to_lower=True,
        counter_to_update=counter_to_update.copy())
    assert cnt4 == Counter(
        {'life': 5, 'is': 2, '.': 2, 'great': 1, '!': 1, 'good': 1, "isn't": 1, 'bad': 1})

def test_count_tokens_from_str():
    _test_count_tokens_from_str_with_delims(' ', '\n')
    _test_count_tokens_from_str_with_delims('IS', 'LIFE')
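
# Vocabulary maps tokens to indices; index 0 belongs to the unknown token, and
# out-of-vocabulary tokens fall back to it.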
def test_tokens_to_indices():
    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])

    vocab = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, unknown_token='<unk>',
                                  reserved_tokens=None)

    i1 = vocab.to_indices('c')
    assert i1 == 1

    i2 = vocab.to_indices(['c'])
    assert i2 == [1]

    i3 = vocab.to_indices(['<unk>', 'non-exist'])
    assert i3 == [0, 0]

    i4 = vocab.to_indices(['a', 'non-exist', 'a', 'b'])
    assert i4 == [3, 0, 3, 2]
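
# The reverse mapping, from indices back to tokens; out-of-range indices raise ValueError.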
def test_indices_to_tokens():
    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])

    vocab = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1,
                                  unknown_token='<unknown>', reserved_tokens=None)

    i1 = vocab.to_tokens(1)
    assert i1 == 'c'

    i2 = vocab.to_tokens([1])
    assert i2 == ['c']

    i3 = vocab.to_tokens([0, 0])
    assert i3 == ['<unknown>', '<unknown>']

    i4 = vocab.to_tokens([3, 0, 3, 2])
    assert i4 == ['a', '<unknown>', 'a', 'b']

    assertRaises(ValueError, vocab.to_tokens, 100)
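
# Registers a tiny pretrained embedding under the name 'test' and loads it by
# name through text.embedding.create.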
def test_download_embed():
    @text.embedding.register
    class Test(text.embedding._TokenEmbedding):
        # 33 bytes.
        pretrained_file_name_sha1 = \
            {'embedding_test.vec': '29b9a6511cf4b5aae293c44a9ec1365b74f2a2f8'}
        namespace = 'test'

        def __init__(self, embedding_root='embeddings', init_unknown_vec=nd.zeros, **kwargs):
            pretrained_file_name = 'embedding_test.vec'
            Test._check_pretrained_file_names(pretrained_file_name)
            super(Test, self).__init__(**kwargs)
            pretrained_file_path = Test._get_pretrained_file(embedding_root, pretrained_file_name)
            self._load_embedding(pretrained_file_path, ' ', init_unknown_vec)

    test_embed = text.embedding.create('test')
    assert test_embed.token_to_idx['hello'] == 1
    assert test_embed.token_to_idx['world'] == 2
    assert_almost_equal(test_embed.idx_to_vec[1].asnumpy(), (nd.arange(5) + 1).asnumpy())
    assert_almost_equal(test_embed.idx_to_vec[2].asnumpy(), (nd.arange(5) + 6).asnumpy())
    assert_almost_equal(test_embed.idx_to_vec[0].asnumpy(), nd.zeros((5,)).asnumpy())
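
# The helpers below write small pretrained-embedding files: one token per line,
# followed by its vector elements. The two *invalid* variants end with a line
# whose vector length differs from the rest, so loading them should fail.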
def _mk_my_pretrain_file(path, token_delim, pretrain_file):
    path = os.path.expanduser(path)
    if not os.path.exists(path):
        os.makedirs(path)
    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
    seqs = seq1 + seq2
    with open(os.path.join(path, pretrain_file), 'w') as fout:
        fout.write(seqs)


def _mk_my_pretrain_file2(path, token_delim, pretrain_file):
    path = os.path.expanduser(path)
    if not os.path.exists(path):
        os.makedirs(path)
    seq1 = token_delim.join(['a', '0.01', '0.02', '0.03', '0.04', '0.05']) + '\n'
    seq2 = token_delim.join(['c', '0.06', '0.07', '0.08', '0.09', '0.1']) + '\n'
    seqs = seq1 + seq2
    with open(os.path.join(path, pretrain_file), 'w') as fout:
        fout.write(seqs)


def _mk_my_pretrain_file3(path, token_delim, pretrain_file):
    path = os.path.expanduser(path)
    if not os.path.exists(path):
        os.makedirs(path)
    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
    seq3 = token_delim.join(['<unk1>', '1.1', '1.2', '1.3', '1.4', '1.5']) + '\n'
    seqs = seq1 + seq2 + seq3
    with open(os.path.join(path, pretrain_file), 'w') as fout:
        fout.write(seqs)


def _mk_my_pretrain_file4(path, token_delim, pretrain_file):
    path = os.path.expanduser(path)
    if not os.path.exists(path):
        os.makedirs(path)
    seq1 = token_delim.join(['a', '0.01', '0.02', '0.03', '0.04', '0.05']) + '\n'
    seq2 = token_delim.join(['c', '0.06', '0.07', '0.08', '0.09', '0.1']) + '\n'
    seq3 = token_delim.join(['<unk2>', '0.11', '0.12', '0.13', '0.14', '0.15']) + '\n'
    seqs = seq1 + seq2 + seq3
    with open(os.path.join(path, pretrain_file), 'w') as fout:
        fout.write(seqs)


def _mk_my_invalid_pretrain_file(path, token_delim, pretrain_file):
    path = os.path.expanduser(path)
    if not os.path.exists(path):
        os.makedirs(path)
    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
    seq3 = token_delim.join(['c']) + '\n'
    seqs = seq1 + seq2 + seq3
    with open(os.path.join(path, pretrain_file), 'w') as fout:
        fout.write(seqs)


def _mk_my_invalid_pretrain_file2(path, token_delim, pretrain_file):
    path = os.path.expanduser(path)
    if not os.path.exists(path):
        os.makedirs(path)
    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
    seq3 = token_delim.join(['c', '0.6', '0.7', '0.8']) + '\n'
    seqs = seq1 + seq2 + seq3
    with open(os.path.join(path, pretrain_file), 'w') as fout:
        fout.write(seqs)
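
# CustomEmbedding loaded without a vocabulary: out-of-vocabulary tokens map to
# index 0, whose vector is given by init_unknown_vec.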
def test_custom_embed():
    embed_root = 'embeddings'
    embed_name = 'my_embed'
    elem_delim = '\t'
    pretrain_file = 'my_pretrain_file.txt'

    _mk_my_pretrain_file(os.path.join(embed_root, embed_name), elem_delim, pretrain_file)

    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file)

    my_embed = text.embedding.CustomEmbedding(pretrain_file_path, elem_delim)

    assert len(my_embed) == 3
    assert my_embed.vec_len == 5
    assert my_embed.token_to_idx['a'] == 1
    assert my_embed.idx_to_token[1] == 'a'

    first_vec = my_embed.idx_to_vec[0]
    assert_almost_equal(first_vec.asnumpy(), np.array([0, 0, 0, 0, 0]))

    unk_vec = my_embed.get_vecs_by_tokens('A')
    assert_almost_equal(unk_vec.asnumpy(), np.array([0, 0, 0, 0, 0]))

    a_vec = my_embed.get_vecs_by_tokens('A', lower_case_backup=True)
    assert_almost_equal(a_vec.asnumpy(), np.array([0.1, 0.2, 0.3, 0.4, 0.5]))

    unk_vecs = my_embed.get_vecs_by_tokens(['<unk$unk@unk>', '<unk$unk@unk>'])
    assert_almost_equal(unk_vecs.asnumpy(), np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]))

    # Test loaded unknown vectors.
    pretrain_file2 = 'my_pretrain_file2.txt'
    _mk_my_pretrain_file3(os.path.join(embed_root, embed_name), elem_delim, pretrain_file2)
    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file2)
    my_embed2 = text.embedding.CustomEmbedding(pretrain_file_path, elem_delim,
                                               init_unknown_vec=nd.ones, unknown_token='<unk>')
    unk_vec2 = my_embed2.get_vecs_by_tokens('<unk>')
    assert_almost_equal(unk_vec2.asnumpy(), np.array([1, 1, 1, 1, 1]))
    unk_vec2 = my_embed2.get_vecs_by_tokens('<unk$unk@unk>')
    assert_almost_equal(unk_vec2.asnumpy(), np.array([1, 1, 1, 1, 1]))

    my_embed3 = text.embedding.CustomEmbedding(pretrain_file_path, elem_delim,
                                               init_unknown_vec=nd.ones, unknown_token='<unk1>')
    unk_vec3 = my_embed3.get_vecs_by_tokens('<unk1>')
    assert_almost_equal(unk_vec3.asnumpy(), np.array([1.1, 1.2, 1.3, 1.4, 1.5]))
    unk_vec3 = my_embed3.get_vecs_by_tokens('<unk$unk@unk>')
    assert_almost_equal(unk_vec3.asnumpy(), np.array([1.1, 1.2, 1.3, 1.4, 1.5]))

    # Test error handling.
    invalid_pretrain_file = 'invalid_pretrain_file.txt'
    _mk_my_invalid_pretrain_file(os.path.join(embed_root, embed_name), elem_delim,
                                 invalid_pretrain_file)
    pretrain_file_path = os.path.join(embed_root, embed_name, invalid_pretrain_file)
    assertRaises(AssertionError, text.embedding.CustomEmbedding, pretrain_file_path, elem_delim)

    invalid_pretrain_file2 = 'invalid_pretrain_file2.txt'
    _mk_my_invalid_pretrain_file2(os.path.join(embed_root, embed_name), elem_delim,
                                  invalid_pretrain_file2)
    pretrain_file_path = os.path.join(embed_root, embed_name, invalid_pretrain_file2)
    assertRaises(AssertionError, text.embedding.CustomEmbedding, pretrain_file_path, elem_delim)
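
# Vocabulary construction under various most_freq_count / min_freq settings,
# with reserved tokens and with non-string (tuple) tokens.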
def test_vocabulary():
    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])

    v1 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, unknown_token='<unk>',
                               reserved_tokens=None)
    assert len(v1) == 5
    assert v1.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3, 'some_word$': 4}
    assert v1.idx_to_token[1] == 'c'
    assert v1.unknown_token == '<unk>'
    assert v1.reserved_tokens is None

    v2 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=2, unknown_token='<unk>',
                               reserved_tokens=None)
    assert len(v2) == 3
    assert v2.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2}
    assert v2.idx_to_token[1] == 'c'
    assert v2.unknown_token == '<unk>'
    assert v2.reserved_tokens is None

    v3 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=100, unknown_token='<unk>',
                               reserved_tokens=None)
    assert len(v3) == 1
    assert v3.token_to_idx == {'<unk>': 0}
    assert v3.idx_to_token[0] == '<unk>'
    assert v3.unknown_token == '<unk>'
    assert v3.reserved_tokens is None

    v4 = text.vocab.Vocabulary(counter, most_freq_count=2, min_freq=1, unknown_token='<unk>',
                               reserved_tokens=None)
    assert len(v4) == 3
    assert v4.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2}
    assert v4.idx_to_token[1] == 'c'
    assert v4.unknown_token == '<unk>'
    assert v4.reserved_tokens is None

    v5 = text.vocab.Vocabulary(counter, most_freq_count=3, min_freq=1, unknown_token='<unk>',
                               reserved_tokens=None)
    assert len(v5) == 4
    assert v5.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3}
    assert v5.idx_to_token[1] == 'c'
    assert v5.unknown_token == '<unk>'
    assert v5.reserved_tokens is None

    v6 = text.vocab.Vocabulary(counter, most_freq_count=100, min_freq=1, unknown_token='<unk>',
                               reserved_tokens=None)
    assert len(v6) == 5
    assert v6.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3,
                               'some_word$': 4}
    assert v6.idx_to_token[1] == 'c'
    assert v6.unknown_token == '<unk>'
    assert v6.reserved_tokens is None

    v7 = text.vocab.Vocabulary(counter, most_freq_count=1, min_freq=2, unknown_token='<unk>',
                               reserved_tokens=None)
    assert len(v7) == 2
    assert v7.token_to_idx == {'<unk>': 0, 'c': 1}
    assert v7.idx_to_token[1] == 'c'
    assert v7.unknown_token == '<unk>'
    assert v7.reserved_tokens is None

    assertRaises(AssertionError, text.vocab.Vocabulary, counter, most_freq_count=None,
                 min_freq=0, unknown_token='<unknown>', reserved_tokens=['b'])

    assertRaises(AssertionError, text.vocab.Vocabulary, counter, most_freq_count=None,
                 min_freq=1, unknown_token='<unknown>', reserved_tokens=['b', 'b'])

    assertRaises(AssertionError, text.vocab.Vocabulary, counter, most_freq_count=None,
                 min_freq=1, unknown_token='<unknown>', reserved_tokens=['b', '<unknown>'])

    v8 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, unknown_token='<unknown>',
                               reserved_tokens=['b'])
    assert len(v8) == 5
    assert v8.token_to_idx == {'<unknown>': 0, 'b': 1, 'c': 2, 'a': 3, 'some_word$': 4}
    assert v8.idx_to_token[1] == 'b'
    assert v8.unknown_token == '<unknown>'
    assert v8.reserved_tokens == ['b']

    v9 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=2, unknown_token='<unk>',
                               reserved_tokens=['b', 'a'])
    assert len(v9) == 4
    assert v9.token_to_idx == {'<unk>': 0, 'b': 1, 'a': 2, 'c': 3}
    assert v9.idx_to_token[1] == 'b'
    assert v9.unknown_token == '<unk>'
    assert v9.reserved_tokens == ['b', 'a']

    v10 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=100, unknown_token='<unk>',
                                reserved_tokens=['b', 'c'])
    assert len(v10) == 3
    assert v10.token_to_idx == {'<unk>': 0, 'b': 1, 'c': 2}
    assert v10.idx_to_token[1] == 'b'
    assert v10.unknown_token == '<unk>'
    assert v10.reserved_tokens == ['b', 'c']

    v11 = text.vocab.Vocabulary(counter, most_freq_count=1, min_freq=2, unknown_token='<unk>',
                                reserved_tokens=['<pad>', 'b'])
    assert len(v11) == 4
    assert v11.token_to_idx == {'<unk>': 0, '<pad>': 1, 'b': 2, 'c': 3}
    assert v11.idx_to_token[1] == '<pad>'
    assert v11.unknown_token == '<unk>'
    assert v11.reserved_tokens == ['<pad>', 'b']

    v12 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=2, unknown_token='b',
                                reserved_tokens=['<pad>'])
    assert len(v12) == 3
    assert v12.token_to_idx == {'b': 0, '<pad>': 1, 'c': 2}
    assert v12.idx_to_token[1] == '<pad>'
    assert v12.unknown_token == 'b'
    assert v12.reserved_tokens == ['<pad>']

    v13 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=2, unknown_token='a',
                                reserved_tokens=['<pad>'])
    assert len(v13) == 4
    assert v13.token_to_idx == {'a': 0, '<pad>': 1, 'c': 2, 'b': 3}
    assert v13.idx_to_token[1] == '<pad>'
    assert v13.unknown_token == 'a'
    assert v13.reserved_tokens == ['<pad>']

    counter_tuple = Counter([('a', 'a'), ('b', 'b'), ('b', 'b'), ('c', 'c'), ('c', 'c'), ('c', 'c'),
                             ('some_word$', 'some_word$')])

    v14 = text.vocab.Vocabulary(counter_tuple, most_freq_count=None, min_freq=1,
                                unknown_token=('<unk>', '<unk>'), reserved_tokens=None)
    assert len(v14) == 5
    assert v14.token_to_idx == {('<unk>', '<unk>'): 0, ('c', 'c'): 1, ('b', 'b'): 2, ('a', 'a'): 3,
                                ('some_word$', 'some_word$'): 4}
    assert v14.idx_to_token[1] == ('c', 'c')
    assert v14.unknown_token == ('<unk>', '<unk>')
    assert v14.reserved_tokens is None
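
# CustomEmbedding restricted to a Vocabulary: tokens absent from the pretrained
# file receive the init_unknown_vec vector, and update_token_vectors validates
# its inputs.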
def test_custom_embedding_with_vocabulary():
    embed_root = 'embeddings'
    embed_name = 'my_embed'
    elem_delim = '\t'
    pretrain_file = 'my_pretrain_file1.txt'

    _mk_my_pretrain_file(os.path.join(embed_root, embed_name), elem_delim, pretrain_file)

    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file)

    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])

    v1 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, unknown_token='<unk>',
                               reserved_tokens=['<pad>'])

    e1 = text.embedding.CustomEmbedding(pretrain_file_path, elem_delim, init_unknown_vec=nd.ones,
                                        vocabulary=v1)

    assert e1.token_to_idx == {'<unk>': 0, '<pad>': 1, 'c': 2, 'b': 3, 'a': 4, 'some_word$': 5}
    assert e1.idx_to_token == ['<unk>', '<pad>', 'c', 'b', 'a', 'some_word$']

    assert_almost_equal(e1.idx_to_vec.asnumpy(),
                        np.array([[1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [0.6, 0.7, 0.8, 0.9, 1],
                                  [0.1, 0.2, 0.3, 0.4, 0.5],
                                  [1, 1, 1, 1, 1]])
                        )

    assert e1.vec_len == 5
    assert e1.reserved_tokens == ['<pad>']

    assert_almost_equal(e1.get_vecs_by_tokens('c').asnumpy(),
                        np.array([1, 1, 1, 1, 1])
                        )

    assert_almost_equal(e1.get_vecs_by_tokens(['c']).asnumpy(),
                        np.array([[1, 1, 1, 1, 1]])
                        )

    assert_almost_equal(e1.get_vecs_by_tokens(['a', 'not_exist']).asnumpy(),
                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                                  [1, 1, 1, 1, 1]])
                        )

    assert_almost_equal(e1.get_vecs_by_tokens(['a', 'b']).asnumpy(),
                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                                  [0.6, 0.7, 0.8, 0.9, 1]])
                        )

    assert_almost_equal(e1.get_vecs_by_tokens(['A', 'b']).asnumpy(),
                        np.array([[1, 1, 1, 1, 1],
                                  [0.6, 0.7, 0.8, 0.9, 1]])
                        )

    assert_almost_equal(e1.get_vecs_by_tokens(['A', 'b'], lower_case_backup=True).asnumpy(),
                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                                  [0.6, 0.7, 0.8, 0.9, 1]])
                        )

    e1.update_token_vectors(['a', 'b'],
                            nd.array([[2, 2, 2, 2, 2],
                                      [3, 3, 3, 3, 3]])
                            )
    assert_almost_equal(e1.idx_to_vec.asnumpy(),
                        np.array([[1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [3, 3, 3, 3, 3],
                                  [2, 2, 2, 2, 2],
                                  [1, 1, 1, 1, 1]])
                        )

    assertRaises(ValueError, e1.update_token_vectors, 'unknown$$$', nd.array([0, 0, 0, 0, 0]))

    assertRaises(AssertionError, e1.update_token_vectors, '<unk>',
                 nd.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]))

    assertRaises(AssertionError, e1.update_token_vectors, '<unk>', nd.array([0]))

    e1.update_token_vectors(['<unk>'], nd.array([0, 0, 0, 0, 0]))
    assert_almost_equal(e1.idx_to_vec.asnumpy(),
                        np.array([[0, 0, 0, 0, 0],
                                  [1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [3, 3, 3, 3, 3],
                                  [2, 2, 2, 2, 2],
                                  [1, 1, 1, 1, 1]])
                        )

    e1.update_token_vectors(['<unk>'], nd.array([[10, 10, 10, 10, 10]]))
    assert_almost_equal(e1.idx_to_vec.asnumpy(),
                        np.array([[10, 10, 10, 10, 10],
                                  [1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [3, 3, 3, 3, 3],
                                  [2, 2, 2, 2, 2],
                                  [1, 1, 1, 1, 1]])
                        )

    e1.update_token_vectors('<unk>', nd.array([0, 0, 0, 0, 0]))
    assert_almost_equal(e1.idx_to_vec.asnumpy(),
                        np.array([[0, 0, 0, 0, 0],
                                  [1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [3, 3, 3, 3, 3],
                                  [2, 2, 2, 2, 2],
                                  [1, 1, 1, 1, 1]])
                        )

    e1.update_token_vectors('<unk>', nd.array([[10, 10, 10, 10, 10]]))
    assert_almost_equal(e1.idx_to_vec.asnumpy(),
                        np.array([[10, 10, 10, 10, 10],
                                  [1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [3, 3, 3, 3, 3],
                                  [2, 2, 2, 2, 2],
                                  [1, 1, 1, 1, 1]])
                        )
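
# CompositeEmbedding built from one source embedding; behaves like the
# vocabulary-backed CustomEmbedding above.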
def test_composite_embedding_with_one_embedding():
    embed_root = 'embeddings'
    embed_name = 'my_embed'
    elem_delim = '\t'
    pretrain_file = 'my_pretrain_file1.txt'

    _mk_my_pretrain_file(os.path.join(embed_root, embed_name), elem_delim, pretrain_file)

    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file)

    my_embed = text.embedding.CustomEmbedding(pretrain_file_path, elem_delim,
                                              init_unknown_vec=nd.ones)

    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])

    v1 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, unknown_token='<unk>',
                               reserved_tokens=['<pad>'])

    ce1 = text.embedding.CompositeEmbedding(v1, my_embed)

    assert ce1.token_to_idx == {'<unk>': 0, '<pad>': 1, 'c': 2, 'b': 3, 'a': 4, 'some_word$': 5}
    assert ce1.idx_to_token == ['<unk>', '<pad>', 'c', 'b', 'a', 'some_word$']

    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                        np.array([[1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [0.6, 0.7, 0.8, 0.9, 1],
                                  [0.1, 0.2, 0.3, 0.4, 0.5],
                                  [1, 1, 1, 1, 1]])
                        )

    assert ce1.vec_len == 5
    assert ce1.reserved_tokens == ['<pad>']

    assert_almost_equal(ce1.get_vecs_by_tokens('c').asnumpy(),
                        np.array([1, 1, 1, 1, 1])
                        )

    assert_almost_equal(ce1.get_vecs_by_tokens(['c']).asnumpy(),
                        np.array([[1, 1, 1, 1, 1]])
                        )

    assert_almost_equal(ce1.get_vecs_by_tokens(['a', 'not_exist']).asnumpy(),
                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                                  [1, 1, 1, 1, 1]])
                        )

    assert_almost_equal(ce1.get_vecs_by_tokens(['a', 'b']).asnumpy(),
                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                                  [0.6, 0.7, 0.8, 0.9, 1]])
                        )

    assert_almost_equal(ce1.get_vecs_by_tokens(['A', 'b']).asnumpy(),
                        np.array([[1, 1, 1, 1, 1],
                                  [0.6, 0.7, 0.8, 0.9, 1]])
                        )

    assert_almost_equal(ce1.get_vecs_by_tokens(['A', 'b'], lower_case_backup=True).asnumpy(),
                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                                  [0.6, 0.7, 0.8, 0.9, 1]])
                        )

    ce1.update_token_vectors(['a', 'b'],
                             nd.array([[2, 2, 2, 2, 2],
                                       [3, 3, 3, 3, 3]])
                             )
    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                        np.array([[1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [3, 3, 3, 3, 3],
                                  [2, 2, 2, 2, 2],
                                  [1, 1, 1, 1, 1]])
                        )

    assertRaises(ValueError, ce1.update_token_vectors, 'unknown$$$', nd.array([0, 0, 0, 0, 0]))

    assertRaises(AssertionError, ce1.update_token_vectors, '<unk>',
                 nd.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]))

    assertRaises(AssertionError, ce1.update_token_vectors, '<unk>', nd.array([0]))

    ce1.update_token_vectors(['<unk>'], nd.array([0, 0, 0, 0, 0]))
    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                        np.array([[0, 0, 0, 0, 0],
                                  [1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [3, 3, 3, 3, 3],
                                  [2, 2, 2, 2, 2],
                                  [1, 1, 1, 1, 1]])
                        )

    ce1.update_token_vectors(['<unk>'], nd.array([[10, 10, 10, 10, 10]]))
    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                        np.array([[10, 10, 10, 10, 10],
                                  [1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [3, 3, 3, 3, 3],
                                  [2, 2, 2, 2, 2],
                                  [1, 1, 1, 1, 1]])
                        )

    ce1.update_token_vectors('<unk>', nd.array([0, 0, 0, 0, 0]))
    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                        np.array([[0, 0, 0, 0, 0],
                                  [1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [3, 3, 3, 3, 3],
                                  [2, 2, 2, 2, 2],
                                  [1, 1, 1, 1, 1]])
                        )

    ce1.update_token_vectors('<unk>', nd.array([[10, 10, 10, 10, 10]]))
    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                        np.array([[10, 10, 10, 10, 10],
                                  [1, 1, 1, 1, 1],
                                  [1, 1, 1, 1, 1],
                                  [3, 3, 3, 3, 3],
                                  [2, 2, 2, 2, 2],
                                  [1, 1, 1, 1, 1]])
                        )
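
# CompositeEmbedding built from two source embeddings concatenates their vectors
# (5 + 5 == 10 here); a token missing from one source gets that source's unknown vector.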
def test_composite_embedding_with_two_embeddings():
    embed_root = '.'
    embed_name = 'my_embed'
    elem_delim = '\t'
    pretrain_file1 = 'my_pretrain_file1.txt'
    pretrain_file2 = 'my_pretrain_file2.txt'

    _mk_my_pretrain_file(os.path.join(embed_root, embed_name), elem_delim, pretrain_file1)
    _mk_my_pretrain_file2(os.path.join(embed_root, embed_name), elem_delim, pretrain_file2)

    pretrain_file_path1 = os.path.join(embed_root, embed_name, pretrain_file1)
    pretrain_file_path2 = os.path.join(embed_root, embed_name, pretrain_file2)

    my_embed1 = text.embedding.CustomEmbedding(pretrain_file_path1, elem_delim,
                                               init_unknown_vec=nd.ones)
    my_embed2 = text.embedding.CustomEmbedding(pretrain_file_path2, elem_delim)

    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])

    v1 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, unknown_token='<unk>',
                               reserved_tokens=None)
    ce1 = text.embedding.CompositeEmbedding(v1, [my_embed1, my_embed2])

    assert ce1.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3, 'some_word$': 4}
    assert ce1.idx_to_token == ['<unk>', 'c', 'b', 'a', 'some_word$']

    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                        np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                                  [1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1],
                                  [0.6, 0.7, 0.8, 0.9, 1, 0, 0, 0, 0, 0],
                                  [0.1, 0.2, 0.3, 0.4, 0.5,
                                   0.01, 0.02, 0.03, 0.04, 0.05],
                                  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
                        )

    assert ce1.vec_len == 10
    assert ce1.reserved_tokens is None

    assert_almost_equal(ce1.get_vecs_by_tokens('c').asnumpy(),
                        np.array([1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1])
                        )

    assert_almost_equal(ce1.get_vecs_by_tokens(['b', 'not_exist']).asnumpy(),
                        np.array([[0.6, 0.7, 0.8, 0.9, 1, 0, 0, 0, 0, 0],
                                  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
                        )

    ce1.update_token_vectors(['a', 'b'],
                             nd.array([[2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
                                       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3]])
                             )
    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                        np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                                  [1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1],
                                  [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
                                  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
                                  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
                        )

    # Test loaded unknown tokens
    pretrain_file3 = 'my_pretrain_file3.txt'
    pretrain_file4 = 'my_pretrain_file4.txt'

    _mk_my_pretrain_file3(os.path.join(embed_root, embed_name), elem_delim, pretrain_file3)
    _mk_my_pretrain_file4(os.path.join(embed_root, embed_name), elem_delim, pretrain_file4)

    pretrain_file_path3 = os.path.join(embed_root, embed_name, pretrain_file3)
    pretrain_file_path4 = os.path.join(embed_root, embed_name, pretrain_file4)

    my_embed3 = text.embedding.CustomEmbedding(pretrain_file_path3, elem_delim,
                                               init_unknown_vec=nd.ones, unknown_token='<unk1>')
    my_embed4 = text.embedding.CustomEmbedding(pretrain_file_path4, elem_delim,
                                               unknown_token='<unk2>')

    v2 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, unknown_token='<unk>',
                               reserved_tokens=None)
    ce2 = text.embedding.CompositeEmbedding(v2, [my_embed3, my_embed4])
    assert_almost_equal(ce2.idx_to_vec.asnumpy(),
                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
                                   0.11, 0.12, 0.13, 0.14, 0.15],
                                  [1.1, 1.2, 1.3, 1.4, 1.5,
                                   0.06, 0.07, 0.08, 0.09, 0.1],
                                  [0.6, 0.7, 0.8, 0.9, 1,
                                   0.11, 0.12, 0.13, 0.14, 0.15],
                                  [0.1, 0.2, 0.3, 0.4, 0.5,
                                   0.01, 0.02, 0.03, 0.04, 0.05],
                                  [1.1, 1.2, 1.3, 1.4, 1.5,
                                   0.11, 0.12, 0.13, 0.14, 0.15]])
                        )

    v3 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, unknown_token='<unk1>',
                               reserved_tokens=None)
    ce3 = text.embedding.CompositeEmbedding(v3, [my_embed3, my_embed4])
    assert_almost_equal(ce3.idx_to_vec.asnumpy(),
                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
                                   0.11, 0.12, 0.13, 0.14, 0.15],
                                  [1.1, 1.2, 1.3, 1.4, 1.5,
                                   0.06, 0.07, 0.08, 0.09, 0.1],
                                  [0.6, 0.7, 0.8, 0.9, 1,
                                   0.11, 0.12, 0.13, 0.14, 0.15],
                                  [0.1, 0.2, 0.3, 0.4, 0.5,
                                   0.01, 0.02, 0.03, 0.04, 0.05],
                                  [1.1, 1.2, 1.3, 1.4, 1.5,
                                   0.11, 0.12, 0.13, 0.14, 0.15]])
                        )

    v4 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, unknown_token='<unk2>',
                               reserved_tokens=None)
    ce4 = text.embedding.CompositeEmbedding(v4, [my_embed3, my_embed4])
    assert_almost_equal(ce4.idx_to_vec.asnumpy(),
                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
                                   0.11, 0.12, 0.13, 0.14, 0.15],
                                  [1.1, 1.2, 1.3, 1.4, 1.5,
                                   0.06, 0.07, 0.08, 0.09, 0.1],
                                  [0.6, 0.7, 0.8, 0.9, 1,
                                   0.11, 0.12, 0.13, 0.14, 0.15],
                                  [0.1, 0.2, 0.3, 0.4, 0.5,
                                   0.01, 0.02, 0.03, 0.04, 0.05],
                                  [1.1, 1.2, 1.3, 1.4, 1.5,
                                   0.11, 0.12, 0.13, 0.14, 0.15]])
                        )

    counter2 = Counter(['b', 'b', 'c', 'c', 'c', 'some_word$'])

    v5 = text.vocab.Vocabulary(counter2, most_freq_count=None, min_freq=1, unknown_token='a',
                               reserved_tokens=None)
    ce5 = text.embedding.CompositeEmbedding(v5, [my_embed3, my_embed4])
    assert ce5.token_to_idx == {'a': 0, 'c': 1, 'b': 2, 'some_word$': 3}
    assert ce5.idx_to_token == ['a', 'c', 'b', 'some_word$']
    assert_almost_equal(ce5.idx_to_vec.asnumpy(),
                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
                                   0.11, 0.12, 0.13, 0.14, 0.15],
                                  [1.1, 1.2, 1.3, 1.4, 1.5,
                                   0.06, 0.07, 0.08, 0.09, 0.1],
                                  [0.6, 0.7, 0.8, 0.9, 1,
                                   0.11, 0.12, 0.13, 0.14, 0.15],
                                  [1.1, 1.2, 1.3, 1.4, 1.5,
                                   0.11, 0.12, 0.13, 0.14, 0.15]])
                        )
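
# The registry of downloadable pretrained files for GloVe and fastText; unknown
# embedding names raise KeyError.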
def test_get_and_pretrain_file_names():
    assert len(text.embedding.get_pretrained_file_names(
        embedding_name='fasttext')) == 327

    assert len(text.embedding.get_pretrained_file_names(embedding_name='glove')) == 10

    reg = text.embedding.get_pretrained_file_names(embedding_name=None)

    assert len(reg['glove']) == 10
    assert len(reg['fasttext']) == 327

    assertRaises(KeyError, text.embedding.get_pretrained_file_names, 'unknown$$')

if __name__ == '__main__':
    import nose
    nose.runmodule()