Byte Pair Encoding

Sennrich, Rico, Barry Haddow, and Alexandra Birch. "Neural Machine Translation of Rare Words with Subword Units." Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), 2016. http://www.aclweb.org/anthology/P16-1162

Implementation available on GitHub: https://github.com/rsennrich/subword-nmt

Training algorithm

  • Compute the frequencies of all words in the training corpus.
  • Start with a vocabulary consisting of all singleton symbols (characters) occurring in the training corpus.
  • To obtain a vocabulary of n merges, iterate n times:
    1. Find the most frequent pair of adjacent symbols in the training data.
    2. Add the pair to the list of merges.
    3. Add the merged symbol to the vocabulary (the code below tracks only the merge list; see the sketch after this list for deriving the vocabulary).
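
The reference implementation below records only the merge operations, since that is all the encoder needs. If an explicit symbol vocabulary is required, it can be derived from the initial corpus plus the merge list. A minimal sketch, assuming words are space-separated symbol sequences as in the training data below (the helper name build_vocab is ours, not part of subword-nmt):

def build_vocab(train_data, merges):
    """Singleton symbols from the corpus plus one new symbol per merge."""
    vocab = set()
    for word in train_data:          # each key is a space-separated symbol sequence
        vocab.update(word.split())
    for first, second in merges:     # each merge contributes its concatenation
        vocab.add(first + second)
    return vocab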
In [2]:
import re, collections

def get_stats(vocab):
    """Compute frequencies of adjacent pairs of symbols."""
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs

def merge_vocab(pair, v_in):
    """Replace all occurrences of the symbol pair with its merged symbol."""
    v_out = {}
    bigram = re.escape(' '.join(pair))
    # match the pair only as whole symbols, not inside a longer symbol
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out
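
As a quick sanity check, here is one merge step on a two-word toy vocabulary; the frequencies and the resulting merge follow directly from the two word counts (tie-breaking depends on dict iteration order, i.e. insertion order on Python 3.7+):

toy = {'l o w </w>': 5, 'l o w e r </w>': 2}
pairs = get_stats(toy)
# ('l', 'o') and ('o', 'w') both occur 5 + 2 = 7 times; max() keeps the
# first tied pair it encounters, here ('l', 'o')
best = max(pairs, key=pairs.get)
print(best)                    # ('l', 'o')
print(merge_vocab(best, toy))  # {'lo w </w>': 5, 'lo w e r </w>': 2}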
In [3]:
from IPython.display import display, Markdown, Latex

train_data = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

bpe_codes = {}
bpe_codes_reverse = {}

num_merges = 10

for i in range(num_merges):
    display(Markdown("### Iteration {}".format(i + 1)))
    pairs = get_stats(train_data)
    best = max(pairs, key=pairs.get)
    train_data = merge_vocab(best, train_data)
    
    bpe_codes[best] = i
    bpe_codes_reverse[best[0] + best[1]] = best
    
    print("new merge: {}".format(best))
    print("train data: {}".format(train_data))
    

Iteration 1

new merge: ('t', '</w>')
train data: {'n e w e s t</w>': 6, 'l o w e r </w>': 2, 'l o w </w>': 5, 'w i d e s t</w>': 3}

Iteration 2

new merge: ('s', 't</w>')
train data: {'l o w e r </w>': 2, 'n e w e st</w>': 6, 'w i d e st</w>': 3, 'l o w </w>': 5}

Iteration 3

new merge: ('e', 'st</w>')
train data: {'l o w e r </w>': 2, 'l o w </w>': 5, 'n e w est</w>': 6, 'w i d est</w>': 3}

Iteration 4

new merge: ('l', 'o')
train data: {'w i d est</w>': 3, 'lo w </w>': 5, 'n e w est</w>': 6, 'lo w e r </w>': 2}

Iteration 5

new merge: ('lo', 'w')
train data: {'low </w>': 5, 'low e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}

Iteration 6

new merge: ('e', 'w')
train data: {'w i d est</w>': 3, 'low </w>': 5, 'low e r </w>': 2, 'n ew est</w>': 6}

Iteration 7

new merge: ('ew', 'est</w>')
train data: {'n ewest</w>': 6, 'low </w>': 5, 'low e r </w>': 2, 'w i d est</w>': 3}

Iteration 8

new merge: ('n', 'ewest</w>')
train data: {'newest</w>': 6, 'low </w>': 5, 'low e r </w>': 2, 'w i d est</w>': 3}

Iteration 9

new merge: ('low', '</w>')
train data: {'newest</w>': 6, 'w i d est</w>': 3, 'low</w>': 5, 'low e r </w>': 2}

Iteration 10

new merge: ('i', 'd')
train data: {'low e r </w>': 2, 'newest</w>': 6, 'w id est</w>': 3, 'low</w>': 5}
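
After these ten iterations, bpe_codes maps each merged pair to the iteration in which it was learned, and the encoder below uses this index as a priority. The table can be inspected in merge order:

# lowest index = learned earliest = applied first during encoding
for pair, idx in sorted(bpe_codes.items(), key=lambda kv: kv[1]):
    print(idx, pair)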

Apply BPE

While possible:

  1. Get all symbol bigrams in the word.
  2. Among them, find the pair that appears earliest in the list of learned merges.
  3. Apply that merge to the word.
In [4]:
def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as a tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def encode(orig):
    """Encode word based on list of BPE merge operations, which are applied consecutively"""

    word = tuple(orig) + ('</w>',)
    display(Markdown("__word split into characters:__ <tt>{}</tt>".format(word)))

    pairs = get_pairs(word)    

    if not pairs:
        return orig

    iteration = 0
    while True:
        iteration += 1
        display(Markdown("__Iteration {}:__".format(iteration)))
        
        print("bigrams in the word: {}".format(pairs))
        bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float('inf')))
        print("candidate for merging: {}".format(bigram))
        if bigram not in bpe_codes:
            display(Markdown("__Candidate not in BPE merges, algorithm stops.__"))
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                # find the next occurrence of the first symbol
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except ValueError:
                # first symbol does not occur again; keep the rest and stop
                new_word.extend(word[i:])
                break

            if word[i] == first and i < len(word)-1 and word[i+1] == second:
                new_word.append(first+second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        print("word after merging: {}".format(word))
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)

    # don't print end-of-word symbols
    if word[-1] == '</w>':
        word = word[:-1]
    elif word[-1].endswith('</w>'):
        word = word[:-1] + (word[-1].replace('</w>',''),)
   
    return word

The word lowest was not in the training data, but both low and the ending est</w> are learned merges, so the word splits as we would expect.

In [8]:
encode("lowest")

word split into characters: ('l', 'o', 'w', 'e', 's', 't', '</w>')

Iteration 1:

bigrams in the word: {('l', 'o'), ('t', '</w>'), ('o', 'w'), ('w', 'e'), ('e', 's'), ('s', 't')}
candidate for merging: ('t', '</w>')
word after merging: ('l', 'o', 'w', 'e', 's', 't</w>')

Iteration 2:

bigrams in the word: {('l', 'o'), ('w', 'e'), ('e', 's'), ('s', 't</w>'), ('o', 'w')}
candidate for merging: ('s', 't</w>')
word after merging: ('l', 'o', 'w', 'e', 'st</w>')

Iteration 3:

bigrams in the word: {('l', 'o'), ('w', 'e'), ('e', 'st</w>'), ('o', 'w')}
candidate for merging: ('e', 'st</w>')
word after merging: ('l', 'o', 'w', 'est</w>')

Iteration 4:

bigrams in the word: {('l', 'o'), ('w', 'est</w>'), ('o', 'w')}
candidate for merging: ('l', 'o')
word after merging: ('lo', 'w', 'est</w>')

Iteration 5:

bigrams in the word: {('w', 'est</w>'), ('lo', 'w')}
candidate for merging: ('lo', 'w')
word after merging: ('low', 'est</w>')

Iteration 6:

bigrams in the word: {('low', 'est</w>')}
candidate for merging: ('low', 'est</w>')

Candidate not in BPE merges, algorithm stops.

Out[8]:
('low', 'est')
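
As a further check, a word that did occur in training should reproduce its final training segmentation: after the tenth merge, widest was stored as 'w id est</w>', so encoding it should yield the same three subwords.

encode("widest")  # expected: ('w', 'id', 'est'), matching 'w id est</w>' above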