Sennrich, Rico, Barry Haddow, and Alexandra Birch. "Neural Machine Translation of Rare Words with Subword Units." Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), 2016. http://www.aclweb.org/anthology/P16-1162
Implementation available on GitHub: https://github.com/rsennrich/subword-nmt
import re, collections

def get_stats(vocab):
    """Compute frequencies of adjacent pairs of symbols."""
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            # count each adjacent symbol pair, weighted by the word's corpus frequency
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs
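As a quick sanity check (a hypothetical toy vocabulary; the names toy_vocab and stats are ours): each word is a string of space-separated symbols ending in the end-of-word marker </w>, and pair counts are weighted by word frequency.

toy_vocab = {'l o w </w>': 5, 'n e w e s t </w>': 6}
stats = get_stats(toy_vocab)
print(stats[('l', 'o')])  # 5: the pair occurs once in a word of frequency 5
print(stats[('e', 's')])  # 6: the pair occurs once in a word of frequency 6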
def merge_vocab(pair, v_in):
    """Merge all occurrences of the symbol pair `pair` into a single new symbol."""
    v_out = {}
    bigram = re.escape(' '.join(pair))
    # (?<!\S) and (?!\S) make sure we only match complete symbols,
    # i.e. the pair must be bounded by whitespace or the string edges
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out
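For illustration (demo_vocab is a hypothetical input), merging the pair ('e', 's') rewrites every word in which those two symbols are adjacent:

demo_vocab = {'n e w e s t </w>': 6, 'w i d e s t </w>': 3}
print(merge_vocab(('e', 's'), demo_vocab))
# {'n e w es t </w>': 6, 'w i d es t </w>': 3}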
from IPython.display import display, Markdown

train_data = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

bpe_codes = {}
bpe_codes_reverse = {}

num_merges = 10

for i in range(num_merges):
    display(Markdown("### Iteration {}".format(i + 1)))

    pairs = get_stats(train_data)
    best = max(pairs, key=pairs.get)  # most frequent pair of adjacent symbols
    train_data = merge_vocab(best, train_data)

    # remember the merge and its rank (lower rank = learned earlier = higher priority)
    bpe_codes[best] = i
    bpe_codes_reverse[best[0] + best[1]] = best

    print("new merge: {}".format(best))
    print("train data: {}".format(train_data))
To segment a new word at test time, we split it into characters (plus the end-of-word marker) and then, while a merge is still possible, greedily apply the learned operation with the lowest rank, i.e. the pair that was merged earliest during training:
def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
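For instance (the element order of the printed set may vary):

print(get_pairs(('l', 'o', 'w', '</w>')))
# {('l', 'o'), ('o', 'w'), ('w', '</w>')}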
def encode(orig):
    """Encode word based on list of BPE merge operations, which are applied consecutively."""
    word = tuple(orig) + ('</w>',)
    display(Markdown("__word split into characters:__ <tt>{}</tt>".format(word)))

    pairs = get_pairs(word)
    if not pairs:
        return orig

    iteration = 0
    while True:
        iteration += 1
        display(Markdown("__Iteration {}:__".format(iteration)))
        print("bigrams in the word: {}".format(pairs))

        # pick the candidate pair that was merged earliest during training
        bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float('inf')))
        print("candidate for merging: {}".format(bigram))
        if bigram not in bpe_codes:
            display(Markdown("__Candidate not in BPE merges, algorithm stops.__"))
            break

        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                # copy everything up to the next occurrence of `first`
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except ValueError:
                new_word.extend(word[i:])
                break

            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1

        word = tuple(new_word)
        print("word after merging: {}".format(word))
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)

    # don't print end-of-word symbols
    if word[-1] == '</w>':
        word = word[:-1]
    elif word[-1].endswith('</w>'):
        word = word[:-1] + (word[-1].replace('</w>', ''),)

    return word
The word "lowest" did not occur in the training data. However, both "low" and the ending "est" were learned as merges, so the word is segmented as we would expect:
encode("lowest")