Introduction to HuggingFace Tokenizers

Author

Cameron Barker

Published

February 25, 2025

Introduction

In this notebook, we will be exploring the HuggingFace Tokenizers library. This library is used to preprocess text data for use in NLP models. We will cover the basics of training a BPE tokenizer similar to the one used in Llama 3 and then use what we have learned to design a custom character-level tokenizer.

We will start our journey with the Llama 3.1 8B HuggingFace repository where we will find the tokenizer_config.json, tokenizer.json and special_tokens_map.json files that define the tokenizer used in Llama 3.1 8B.

It is worth noting that HuggingFace Transformers has two main tokenizer base classes: PreTrainedTokenizer and PreTrainedTokenizerFast, both of which inherit from PreTrainedTokenizerBase.

PreTrainedTokenizer is the legacy pure-Python tokenizer, while PreTrainedTokenizerFast wraps the newer Rust-based tokenizer from the HuggingFace Tokenizers library. When using PreTrainedTokenizerFast, several settings and config files, such as special_tokens_map.json, are vestigial. Hence we will focus only on the configurations necessary to recreate the tokenizer used in Llama 3.1 8B rather than the full set.
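
As a quick sanity check, you can see which class AutoTokenizer hands back by inspecting the instance; is_fast is True for the Rust-backed tokenizer. This snippet uses the same gated Llama repo that we load later in the notebook.

from transformers import AutoTokenizer, PreTrainedTokenizerFast

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
print(type(tok).__name__, tok.is_fast, isinstance(tok, PreTrainedTokenizerFast))
# Expected: PreTrainedTokenizerFast True True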

tokenizer_config.json
{
  "bos_token": "<|begin_of_text|>",
  "eos_token": "<|end_of_text|>",
  "clean_up_tokenization_spaces": true,
  "model_input_names": ["input_ids", "attention_mask"],
  "model_max_length": 131072,
  "tokenizer_class": "PreTrainedTokenizerFast"
}

The tokenizer_config.json file tells AutoTokenizer how to load the tokenizer and configures the HF Transformers tokenizer class. Note that tokenizer_class tells AutoTokenizer to use the generic PreTrainedTokenizerFast rather than HF's LlamaTokenizerFast implementation.

The rest of the config parameters are passed into the class as keyword arguments. Here you can set the special tokens declared in the SpecialTokensMixin, such as bos_token and eos_token.

clean_up_tokenization_spaces removes the leading whitespace before certain punctuation and contractions after decoding token IDs back to text. It is not strictly needed, as it does not affect token encoding and the same behaviour can be built directly into the tokenizer's decoder.
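
As a rough illustration of the cleanup this flag toggles, the underlying clean_up_tokenization helper in HF Transformers strips the space that decoding can leave before punctuation and contractions (treat the exact behaviour as a sketch from my reading of the source):

from transformers import PreTrainedTokenizerFast

# clean_up_tokenization is a static helper, so it can be called without an instance
print(PreTrainedTokenizerFast.clean_up_tokenization("Hello , world ! It 's fine ."))
# Expected: Hello, world! It's fine.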

model_input_names tells the tokenizer which inputs to generate for the model. For NLP models, the inputs are typically input_ids, token_type_ids, and attention_mask.

model_max_length is the maximum input sequence length that the tokenizer will accept. It is somewhat arbitrary here, as a similar parameter set in the model config has a greater effect.
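
Putting the above together, loading this config is roughly equivalent to constructing the fast tokenizer directly with the same keyword arguments (a sketch that assumes tokenizer.json has been downloaded to the working directory):

from transformers import PreTrainedTokenizerFast

tok = PreTrainedTokenizerFast(
    tokenizer_file="tokenizer.json",  # assumed to be a local copy of the Llama 3.1 file
    bos_token="<|begin_of_text|>",
    eos_token="<|end_of_text|>",
    clean_up_tokenization_spaces=True,
    model_input_names=["input_ids", "attention_mask"],
    model_max_length=131072,
)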

tokenizer.json
{
  "version": "1.0",
  "truncation": null,
  "padding": null,
  "added_tokens": ...,
  "normalizer": null,
  "pre_tokenizer": ...,
  "post_processor": ...,
  "decoder": ...,
  "model": ...
}

This file is passed to the Tokenizer class of the HF Tokenizers library to create the fast tokenizer itself.

Loosely, each field in the tokenizer.json file acts as a config for a corresponding method in the Tokenizer class.
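
You can see this mapping by loading the file directly with the HF Tokenizers library (assuming you have downloaded tokenizer.json from the Llama 3.1 repo):

from tokenizers import Tokenizer

raw_tokenizer = Tokenizer.from_file("tokenizer.json")

# Each JSON field surfaces as an attribute on the Tokenizer object;
# truncation and padding instead map to enable_truncation()/enable_padding().
print(type(raw_tokenizer.pre_tokenizer).__name__)   # Sequence
print(type(raw_tokenizer.post_processor).__name__)  # Sequence
print(type(raw_tokenizer.decoder).__name__)         # ByteLevel
print(type(raw_tokenizer.model).__name__)           # BPE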

Tokenizer Model

Here we will define the type of model used in the tokenizer. BPE, Unigram, WordLevel and WordPiece are supported. Each model is a reimplementation of a previously popular tokenizer and retains the quirks and peculiarities of the original implementation. Algorithmically the models are fairly similar, with the main difference being the rules for merging tokens, but in practice each one can feel quite different to work with, especially with regards to (sub)word boundaries.

tokenizer.json
"model": {
  "type": "BPE",
  "dropout": null,
  "unk_token": null,
  "continuing_subword_prefix": null,
  "end_of_word_suffix": null,
  "fuse_unk": false,
  "byte_fallback": false,
  "ignore_merges": true,
  "vocab": {
      ...
  },
  "merges": [
      ...
  ]
}

Below we can see how the parameters from the tokenizer.json file above are used to create the tokenizer model. As we are going to train our own tokenizer, we won't pass in the vocab and merges parameters as they will be learned during training.

We will also take this opportunity to set the vocab_size to 128256, which can be found in the model's config.json file or by incrementing the largest token ID in tokenizer.json. In Llama 3 this corresponds to 128000 regular tokens plus 256 special tokens. Only the first two special tokens are used; the rest are reserved for future use, such as helping to structure text for instruction fine-tuning. In that case you can rename the reserved special tokens or even retrain a PreTrainedTokenizerFast using the train_new_from_iterator method (a rough sketch follows the code below).

from tokenizers import Tokenizer, models

test_str = "The quick brown fox jumps over the lazy dog 😂"
vocab_size = 128256
non_special_vocab_size = 128000
tokenizer = Tokenizer(
    models.BPE(
        dropout=None,
        unk_token=None,
        # The next 2 args default to None,
        # but we can't set them to None directly
        # due to an oddity in the rust to python bindings
        # continuing_subword_prefix=None,
        # end_of_word_suffix=None,
        fuse_unk=False,
        byte_fallback=False,
        ignore_merges=True,
    )
)
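
As an aside, the train_new_from_iterator route mentioned above might look roughly like the sketch below; the corpus here is only a placeholder, and the call relearns the BPE vocab and merges while keeping the rest of the Llama 3.1 pipeline.

from transformers import AutoTokenizer

# Placeholder corpus; swap in your own domain-specific text
my_corpus = ["some domain specific text", "more domain specific text"]

base = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
retrained = base.train_new_from_iterator(my_corpus, vocab_size=128000)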

Normalizer

The normalizer is used to preprocess the text before tokenization. In general it performs some sort of string replacement.

While Llama 3 does not use a normalizer, we will show an example of how to define one below.

"""
from tokenizers import normalizers

tokenizer.normalizer = normalizers.Sequence(
    [
        normalizers.Lowercase(), # Lowercase the input
        normalizers.Strip(), # Remove leading/trailing whitespaces
    ]
)
"""

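As a standalone check of what that (unused) normalizer would do to a string:

from tokenizers import normalizers

norm = normalizers.Sequence([normalizers.Lowercase(), normalizers.Strip()])
print(repr(norm.normalize_str("  Hello WORLD  ")))
# Expected: 'hello world'
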
Pretokenizing

The pretokenizer is used to split the text into tokens before the tokenizer model is applied. Simple rules define where the text will always be split. This also applies to the training corpus so the tokenizer won’t learn to merge tokens across these split boundaries.

In Llama 3, the pretokenizer is a sequence of two pretokenizers: Split and ByteLevel. Split is used to apply a regex pattern to the text. See https://regex101.com/r/m6gEfJ/1 for a breakdown of the regex pattern used in Llama 3. Loosely, the pattern splits before whitespace, punctuation and numbers.

The result of this splitting can be seen in the pre_tokenize_str output below; however, you will also notice that spaces have been replaced with Ġ, which is caused by the ByteLevel pretokenizer. Additionally, the emoji is split into several tokens. This is because the ByteLevel pretokenizer treats each byte as a token rather than each character. In UTF-8 encoding, on top of the standard ASCII characters, additional characters are represented with multi-byte sequences. Individually, the bytes reserved for these multi-byte sequences can't be printed as characters. To circumvent this, the bytes are remapped so that unprintable bytes are replaced with the next available printable character. HF Tokenizers borrows from the OpenAI GPT-2 implementation, which defines a list of bytes to keep as-is and then maps each remaining byte onto the next unused printable character. OpenAI chose to also remove the space character from the kept list because other parts of their codebase assumed the tokenizer had removed spaces. Space is the 32nd UTF-8 character and Ġ is the 32nd remapped printable (at least under OpenAI's interpretation) character.
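
To make that remapping concrete, here is a minimal sketch of the GPT-2 style byte-to-character table (adapted from the well-known bytes_to_unicode function rather than the HF Tokenizers internals):

# Bytes that are already printable are kept as-is; every other byte is shifted
# into the next free printable code point starting at 256.
def bytes_to_unicode():
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(0xA1, 0xAD)) + list(range(0xAE, 0x100))
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(0x100 + n)
            n += 1
    return dict(zip(bs, (chr(c) for c in cs)))

byte_map = bytes_to_unicode()
print(byte_map[ord(" ")])                           # Ġ  (space is byte 32)
print([byte_map[b] for b in "😂".encode("utf-8")])  # ['ð', 'Ł', 'ĺ', 'Ĥ'], matching ĠðŁĺĤ below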

You may have noticed that the trim_offsets parameter is set to true for the ByteLevel pretokenizer, yet it is not set in the code below. If you are running the code at home, you may also notice that the trim_offsets parameter is not available in the ByteLevel pretokenizer class. This is because in the Rust backend the ByteLevel struct is shared between the PreTokenizer, PostProcessor and Decoder, with default values of true for add_prefix_space, trim_offsets and use_regex. Therefore you only need to pass the parameters you want set to false when constructing pre_tokenizers/processors/decoders.ByteLevel.

tokenizer.json
"pre_tokenizer": {
  "type": "Sequence",
  "pretokenizers": [
    {
      "type": "Split",
      "pattern": {
        "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
      },
      "behavior": "Isolated",
      "invert": false
    },
    {
      "type": "ByteLevel",
      "add_prefix_space": false,
      "trim_offsets": true,
      "use_regex": false
    }
  ]
}
from tokenizers import pre_tokenizers, Regex

tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
    [
        pre_tokenizers.Split(
            pattern=Regex(
                r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
            ),
            behavior="isolated",
            invert=False,
        ),
        pre_tokenizers.ByteLevel(
            add_prefix_space=False,
            use_regex=False,
        ),
    ]
)
tokenizer.pre_tokenizer.pre_tokenize_str(test_str)
[('The', (0, 3)),
 ('Ġquick', (3, 9)),
 ('Ġbrown', (9, 15)),
 ('Ġfox', (15, 19)),
 ('Ġjumps', (19, 25)),
 ('Ġover', (25, 30)),
 ('Ġthe', (30, 34)),
 ('Ġlazy', (34, 39)),
 ('Ġdog', (39, 43)),
 ('ĠðŁĺĤ', (43, 45))]

Postprocessing

Postprocessing is used to apply a template to the tokens. This can be used to add special tokens to the beginning and end of the sequence or to apply a chatbot-style template. In Llama 3, the postprocessor is a sequence of two postprocessors: ByteLevel and TemplateProcessing. The type_id of each template piece can be set with a : suffix (e.g. $B:1), which populates the token_type_ids mask passed to the model. The Llama 3 models do not use this feature, however, as token_type_ids is not listed in the model_input_names parameter of the tokenizer_config.json file. For the TemplateProcessing postprocessor, you need to define the token to token_id mapping for the special tokens used in the template (a toy illustration of type_ids follows the code below).

tokenizer.json
"post_processor": {
  "type": "Sequence",
  "processors": [
    {
      "type": "ByteLevel",
      "add_prefix_space": true,
      "trim_offsets": false,
      "use_regex": true
    },
    {
      "type": "TemplateProcessing",
      "single": [
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 0
          }
        },
        {
          "Sequence": {
            "id": "A",
            "type_id": 0
          }
        }
      ],
      "pair": [
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 0
          }
        },
        {
          "Sequence": {
            "id": "A",
            "type_id": 0
          }
        },
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 1
          }
        },
        {
          "Sequence": {
            "id": "B",
            "type_id": 1
          }
        }
      ],
      "special_tokens": {
        "<|begin_of_text|>": {
          "id": "<|begin_of_text|>",
          "ids": [
            128000
          ],
          "tokens": [
            "<|begin_of_text|>"
          ]
        }
      }
    }
  ]
}
from tokenizers import processors

tokenizer.post_processor = processors.Sequence(
    [
        processors.ByteLevel(trim_offsets=False),
        processors.TemplateProcessing(
            single="<|begin_of_text|> $A",
            pair="<|begin_of_text|> $A <|begin_of_text|>:1 $B:1",
            special_tokens=[
                ("<|begin_of_text|>", non_special_vocab_size + 0),
            ],
        ),
    ]
)
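
To see what the pair template and type_ids actually produce, here is a throwaway WordLevel tokenizer with a made-up vocab run through the same style of template:

from tokenizers import Tokenizer, models, pre_tokenizers, processors

# Toy tokenizer purely to illustrate the template; the vocab is made up
toy = Tokenizer(models.WordLevel({"<|begin_of_text|>": 0, "hello": 1, "world": 2}, unk_token="<|begin_of_text|>"))
toy.pre_tokenizer = pre_tokenizers.Whitespace()
toy.post_processor = processors.TemplateProcessing(
    single="<|begin_of_text|> $A",
    pair="<|begin_of_text|> $A <|begin_of_text|>:1 $B:1",
    special_tokens=[("<|begin_of_text|>", 0)],
)
enc = toy.encode("hello", "world")
print(enc.tokens)    # ['<|begin_of_text|>', 'hello', '<|begin_of_text|>', 'world']
print(enc.type_ids)  # [0, 0, 1, 1]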

Decoding

This decodes the token IDs back into text. As before, these parameters default to true in the Rust backend, so they don't need to be passed when constructing the decoder.

tokenizer.json
  "decoder": {
    "type": "ByteLevel",
    "add_prefix_space": true,
    "trim_offsets": true,
    "use_regex": true
  }
from tokenizers import decoders

tokenizer.decoder = decoders.ByteLevel()
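
A quick standalone check of the decoder we just attached shows Ġ mapping back to a space and the emoji being reassembled from its remapped bytes:

print(tokenizer.decoder.decode(["The", "Ġquick", "ĠðŁĺĤ"]))
# Expected: The quick 😂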

Training

Llama 3 took 100K pre-trained tokens from OpenAI's tiktoken library and then trained a further 28K tokens on top for better multilingual support. This explains why the Llama 3 and GPT-4 tokenizers look very similar when you compare them. For a general-purpose tokenizer it is likely more convenient to start from a pre-trained tokenizer. However, if you have a domain-specific use case, training your own tokenizer can significantly improve text compression and downstream model performance.

In this case we will train a BPE tokenizer on the WikiText-103 dataset as an example, however you can train on any text data you have available. In practice, if you find yourself training a specialized tokenizer on a large web dataset you may want to consider using the token vocabulary from a pre-trained tokenizer as your starting vocabulary.

Here we will set the initial alphabet to the ByteLevel alphabet to ensure every byte ends up in the vocabulary so that every byte can be tokenized without the need for the unk_token parameter.

After training the tokenizer we will add 256 special tokens and then save it to a tokenizer.json file.

We also run a quick test to validate that the tokenizer is working as expected.

import datasets

# Load and extract the text from the dataset
dataset = datasets.load_dataset(
    "Salesforce/wikitext",
    "wikitext-103-raw-v1",
    split="train",
)
corpus = dataset["text"]
from tokenizers import trainers

# Initialize trainer with target vocab size and set initial alphabet
initial_alphabet = pre_tokenizers.ByteLevel.alphabet()
trainer = trainers.BpeTrainer(
    vocab_size=non_special_vocab_size,
    initial_alphabet=initial_alphabet,
)

# Train tokenizer on dataset corpus
tokenizer.train_from_iterator(
    corpus,
    trainer=trainer,
    length=len(corpus),
)

# Add special tokens to end of vocab
special_tokens = [
    "<|begin_of_text|>",
    "<|end_of_text|>",
    "<|reserved_special_token_0|>",
    "<|reserved_special_token_1|>",
    "<|finetune_right_pad_id|>",
    "<|reserved_special_token_2|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|eom_id|>",
    "<|eot_id|>",
    "<|python_tag|>",
]
special_tokens += [f"<|reserved_special_token_{i}|>" for i in range(3, 248)]
tokenizer.add_special_tokens(special_tokens)

# Save tokenizer to disk
tokenizer.save("tokenizer.json")

# Test encoding and decoding
encoded = tokenizer.encode(test_str)
print(f"encoded ids: {encoded.ids}")
print(f"encoded tokens: {encoded.tokens}")
decoded = tokenizer.decode(encoded.ids)
print(f"decoded: {decoded}")
assert decoded == test_str



encoded ids: [128000, 71343, 2428, 4857, 13756, 19572, 595, 260, 30920, 4722, 220, 172, 253, 246, 224]
encoded tokens: ['<|begin_of_text|>', 'The', 'Ġquick', 'Ġbrown', 'Ġfox', 'Ġjumps', 'Ġover', 'Ġthe', 'Ġlazy', 'Ġdog', 'Ġ', 'ð', 'Ł', 'ĺ', 'Ĥ']
decoded: The quick brown fox jumps over the lazy dog 😂
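
As a rough check on the compression point made earlier, one way to measure it is the average number of characters per token over a slice of the corpus (the slice size here is arbitrary):

# Rough compression estimate; note each encode also adds one <|begin_of_text|> token
sample = [t for t in corpus[:10000] if t.strip()]
n_chars = sum(len(t) for t in sample)
n_tokens = sum(len(tokenizer.encode(t).ids) for t in sample)
print(f"characters per token: {n_chars / n_tokens:.2f}")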

Importing to HF Transformers

If we then want to use this tokenizer in HF Transformers as a PreTrainedTokenizerFast instance, we can load it with AutoTokenizer. First you will need to place a copy of the tokenizer_config.json file from the introduction in the same directory as the trained tokenizer.json file. You can then pass this directory's path to AutoTokenizer.from_pretrained to load the tokenizer.

As you can see the generated token ids are the same as before and the tokenizer is able to recover the original text from the token ids.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(".")

encoded = tokenizer(test_str)
print(f"encoded ids: {encoded['input_ids']}")
decoded = tokenizer.decode(encoded["input_ids"], skip_special_tokens=True)
print(f"decoded: {decoded}")
assert decoded == test_str
encoded ids: [128000, 71343, 2428, 4857, 13756, 19572, 595, 260, 30920, 4722, 220, 172, 253, 246, 224]
decoded: The quick brown fox jumps over the lazy dog 😂

Comparing to Llama 3

Finally we will compare the tokenization of the Llama 3 tokenizer to that of our custom tokenizer. Since we have trained on different data, the learned vocabulary of merged tokens will differ, and so will the resulting tokenization.

In this case we can see that although the token IDs are different, the tokenization is quite similar. Llama 3 likely had more emojis in its training data and so was able to merge some of the emoji's bytes into a single token.

Both were able to recover the original text from their respective token IDs.

our_tokenizer = AutoTokenizer.from_pretrained(".")
llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")

our_encoded = tokenizer(test_str)
print(f"our encoded ids:   {our_encoded['input_ids']}")
llama_encoded = llama_tokenizer(test_str)
print(f"llama encoded ids: {llama_encoded['input_ids']}")

print(f"our encoded tokens:   {tokenizer.convert_ids_to_tokens(our_encoded['input_ids'])}")
print(f"llama encoded tokens: {llama_tokenizer.convert_ids_to_tokens(llama_encoded['input_ids'])}")

our_decoded = tokenizer.decode(our_encoded["input_ids"], skip_special_tokens=True)
print(f"our decoded:   {our_decoded}")
llama_decoded = llama_tokenizer.decode(
    llama_encoded["input_ids"], skip_special_tokens=True
)
print(f"llama decoded: {llama_decoded}")

assert our_decoded == llama_decoded
our encoded ids:   [128000, 71343, 2428, 4857, 13756, 19572, 595, 260, 30920, 4722, 220, 172, 253, 246, 224]
llama encoded ids: [128000, 791, 4062, 14198, 39935, 35308, 927, 279, 16053, 5679, 27623, 224]
our encoded tokens:   ['<|begin_of_text|>', 'The', 'Ġquick', 'Ġbrown', 'Ġfox', 'Ġjumps', 'Ġover', 'Ġthe', 'Ġlazy', 'Ġdog', 'Ġ', 'ð', 'Ł', 'ĺ', 'Ĥ']
llama encoded tokens: ['<|begin_of_text|>', 'The', 'Ġquick', 'Ġbrown', 'Ġfox', 'Ġjumps', 'Ġover', 'Ġthe', 'Ġlazy', 'Ġdog', 'ĠðŁĺ', 'Ĥ']
our decoded:   The quick brown fox jumps over the lazy dog 😂
llama decoded: The quick brown fox jumps over the lazy dog 😂

Fully Custom Tokenizer

As a final example, we will quickly go through how to create a fully custom tokenizer.

The following is an example of a character level tokenizer with a selection of added special tokens. The code has the option of placing the special tokens at the beginning or end of the vocabulary. In this case we have placed them at the end of the vocabulary so that additional special tokens can easily be added in the future.

# Params
initial_alphabet = pre_tokenizers.ByteLevel.alphabet()
vocab_size = len(initial_alphabet)
special_tokens_at_start = False

# Special tokens
bos_token = "<|bos|>"
eos_token = "<|eos|>"
unk_token = "<|unk|>"
sep_token = "<|sep|>"
pad_token = "<|pad|>"
cls_token = "<|cls|>"
mask_token = "<|mask|>"
additional_special_tokens = []
special_tokens = [
    bos_token,
    eos_token,
    # unk_token,
    sep_token,
    # pad_token,
    # cls_token,
    mask_token,
] + additional_special_tokens
template_special_tokens = [
    (token, i if special_tokens_at_start else i + vocab_size)
    for i, token in enumerate(special_tokens)
    if token in [bos_token, eos_token, sep_token]
]

# Tokenizer
tokenizer = Tokenizer(models.BPE())

tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
    [
        pre_tokenizers.Split(
            pattern=Regex("."), behavior="isolated"
        ),  # Split by character
        pre_tokenizers.ByteLevel(
            add_prefix_space=False,
            use_regex=False,
        ),
    ]
)
tokenizer.post_processor = processors.Sequence(
    [
        processors.ByteLevel(trim_offsets=False),
        processors.TemplateProcessing(
            single="<|bos|> $A <|eos|>",
            pair="<|bos|> $A <|sep|> $B:1 <|eos|>:1",
            special_tokens=template_special_tokens,
        ),
    ]
)
tokenizer.decoder = decoders.ByteLevel()

trainer = trainers.BpeTrainer(
    vocab_size=vocab_size,
    initial_alphabet=initial_alphabet,
    special_tokens=(
        special_tokens if special_tokens_at_start else []
    ),  # Add special tokens to start of vocab
)

# Train tokenizer on initial alphabet so that no merges are performed
tokenizer.train_from_iterator(
    list(initial_alphabet),
    trainer=trainer,
    length=vocab_size,
)

if not special_tokens_at_start:
    # Add special tokens to end of vocab
    tokenizer.add_special_tokens(special_tokens)

print(f"vocab_size: {tokenizer.get_vocab_size()}")
print(f"vocab: {dict(sorted(tokenizer.get_vocab().items(), key=lambda item: item[1]))}")


encoded = tokenizer.encode(test_str)
print(f"encoded ids: {encoded.ids}")
print(f"encoded tokens: {encoded.tokens}")
decoded = tokenizer.decode(encoded.ids)
print(f"decoded: {decoded}")

tokenizer.save("tokenizer.json")
tokenizer = AutoTokenizer.from_pretrained(".")
encoded = tokenizer(test_str)
print(f"transformers encoded ids: {encoded['input_ids']}")
decoded = tokenizer.decode(encoded["input_ids"], skip_special_tokens=True)
print(f"transformers decoded: {decoded}")



vocab_size: 260
vocab: {'!': 0, '"': 1, '#': 2, '$': 3, '%': 4, '&': 5, "'": 6, '(': 7, ')': 8, '*': 9, '+': 10, ',': 11, '-': 12, '.': 13, '/': 14, '0': 15, '1': 16, '2': 17, '3': 18, '4': 19, '5': 20, '6': 21, '7': 22, '8': 23, '9': 24, ':': 25, ';': 26, '<': 27, '=': 28, '>': 29, '?': 30, '@': 31, 'A': 32, 'B': 33, 'C': 34, 'D': 35, 'E': 36, 'F': 37, 'G': 38, 'H': 39, 'I': 40, 'J': 41, 'K': 42, 'L': 43, 'M': 44, 'N': 45, 'O': 46, 'P': 47, 'Q': 48, 'R': 49, 'S': 50, 'T': 51, 'U': 52, 'V': 53, 'W': 54, 'X': 55, 'Y': 56, 'Z': 57, '[': 58, '\\': 59, ']': 60, '^': 61, '_': 62, '`': 63, 'a': 64, 'b': 65, 'c': 66, 'd': 67, 'e': 68, 'f': 69, 'g': 70, 'h': 71, 'i': 72, 'j': 73, 'k': 74, 'l': 75, 'm': 76, 'n': 77, 'o': 78, 'p': 79, 'q': 80, 'r': 81, 's': 82, 't': 83, 'u': 84, 'v': 85, 'w': 86, 'x': 87, 'y': 88, 'z': 89, '{': 90, '|': 91, '}': 92, '~': 93, '¡': 94, '¢': 95, '£': 96, '¤': 97, '¥': 98, '¦': 99, '§': 100, '¨': 101, '©': 102, 'ª': 103, '«': 104, '¬': 105, '®': 106, '¯': 107, '°': 108, '±': 109, '²': 110, '³': 111, '´': 112, 'µ': 113, '¶': 114, '·': 115, '¸': 116, '¹': 117, 'º': 118, '»': 119, '¼': 120, '½': 121, '¾': 122, '¿': 123, 'À': 124, 'Á': 125, 'Â': 126, 'Ã': 127, 'Ä': 128, 'Å': 129, 'Æ': 130, 'Ç': 131, 'È': 132, 'É': 133, 'Ê': 134, 'Ë': 135, 'Ì': 136, 'Í': 137, 'Î': 138, 'Ï': 139, 'Ð': 140, 'Ñ': 141, 'Ò': 142, 'Ó': 143, 'Ô': 144, 'Õ': 145, 'Ö': 146, '×': 147, 'Ø': 148, 'Ù': 149, 'Ú': 150, 'Û': 151, 'Ü': 152, 'Ý': 153, 'Þ': 154, 'ß': 155, 'à': 156, 'á': 157, 'â': 158, 'ã': 159, 'ä': 160, 'å': 161, 'æ': 162, 'ç': 163, 'è': 164, 'é': 165, 'ê': 166, 'ë': 167, 'ì': 168, 'í': 169, 'î': 170, 'ï': 171, 'ð': 172, 'ñ': 173, 'ò': 174, 'ó': 175, 'ô': 176, 'õ': 177, 'ö': 178, '÷': 179, 'ø': 180, 'ù': 181, 'ú': 182, 'û': 183, 'ü': 184, 'ý': 185, 'þ': 186, 'ÿ': 187, 'Ā': 188, 'ā': 189, 'Ă': 190, 'ă': 191, 'Ą': 192, 'ą': 193, 'Ć': 194, 'ć': 195, 'Ĉ': 196, 'ĉ': 197, 'Ċ': 198, 'ċ': 199, 'Č': 200, 'č': 201, 'Ď': 202, 'ď': 203, 'Đ': 204, 'đ': 205, 'Ē': 206, 'ē': 207, 'Ĕ': 208, 'ĕ': 209, 'Ė': 210, 'ė': 211, 'Ę': 212, 'ę': 213, 'Ě': 214, 'ě': 215, 'Ĝ': 216, 'ĝ': 217, 'Ğ': 218, 'ğ': 219, 'Ġ': 220, 'ġ': 221, 'Ģ': 222, 'ģ': 223, 'Ĥ': 224, 'ĥ': 225, 'Ħ': 226, 'ħ': 227, 'Ĩ': 228, 'ĩ': 229, 'Ī': 230, 'ī': 231, 'Ĭ': 232, 'ĭ': 233, 'Į': 234, 'į': 235, 'İ': 236, 'ı': 237, 'IJ': 238, 'ij': 239, 'Ĵ': 240, 'ĵ': 241, 'Ķ': 242, 'ķ': 243, 'ĸ': 244, 'Ĺ': 245, 'ĺ': 246, 'Ļ': 247, 'ļ': 248, 'Ľ': 249, 'ľ': 250, 'Ŀ': 251, 'ŀ': 252, 'Ł': 253, 'ł': 254, 'Ń': 255, '<|bos|>': 256, '<|eos|>': 257, '<|sep|>': 258, '<|mask|>': 259}
encoded ids: [256, 51, 71, 68, 220, 80, 84, 72, 66, 74, 220, 65, 81, 78, 86, 77, 220, 69, 78, 87, 220, 73, 84, 76, 79, 82, 220, 78, 85, 68, 81, 220, 83, 71, 68, 220, 75, 64, 89, 88, 220, 67, 78, 70, 220, 172, 253, 246, 224, 257]
encoded tokens: ['<|bos|>', 'T', 'h', 'e', 'Ġ', 'q', 'u', 'i', 'c', 'k', 'Ġ', 'b', 'r', 'o', 'w', 'n', 'Ġ', 'f', 'o', 'x', 'Ġ', 'j', 'u', 'm', 'p', 's', 'Ġ', 'o', 'v', 'e', 'r', 'Ġ', 't', 'h', 'e', 'Ġ', 'l', 'a', 'z', 'y', 'Ġ', 'd', 'o', 'g', 'Ġ', 'ð', 'Ł', 'ĺ', 'Ĥ', '<|eos|>']
decoded: The quick brown fox jumps over the lazy dog 😂
transformers encoded ids: [256, 51, 71, 68, 220, 80, 84, 72, 66, 74, 220, 65, 81, 78, 86, 77, 220, 69, 78, 87, 220, 73, 84, 76, 79, 82, 220, 78, 85, 68, 81, 220, 83, 71, 68, 220, 75, 64, 89, 88, 220, 67, 78, 70, 220, 172, 253, 246, 224, 257]
transformers decoded: The quick brown fox jumps over the lazy dog 😂

Bonus

For the previous example, if we are just trying to get started and don’t care about special tokens, we can skip using tokenizers and directly interpret the bytes of the text as a tensor.

import torch

tokenize = lambda text: torch.ByteTensor(bytearray(text, "utf-8"))
detokenize = lambda tensor: bytes(tensor).decode("utf-8")

tokenize(test_str), detokenize(tokenize(test_str))
(tensor([ 84, 104, 101,  32, 113, 117, 105,  99, 107,  32,  98, 114, 111, 119,
         110,  32, 102, 111, 120,  32, 106, 117, 109, 112, 115,  32, 111, 118,
         101, 114,  32, 116, 104, 101,  32, 108,  97, 122, 121,  32, 100, 111,
         103,  32, 240, 159, 152, 130], dtype=torch.uint8),
 'The quick brown fox jumps over the lazy dog 😂')