
Gemma3 adding new tokens <image_soft_token> has been added accidentally #37011

Closed
Serzhanov opened this issue Mar 26, 2025 · 7 comments

@Serzhanov

Serzhanov commented Mar 26, 2025

System Info

Hello,
When adding custom tokens to the gemma-3-1b-it tokenizer, an unexpected token (<image_soft_token>) shows up among the new rows of the model's embedding matrix, even though it was never explicitly added.

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Model:

from transformers import AutoTokenizer, BitsAndBytesConfig, Gemma3ForCausalLM
import torch

model_id = "google/gemma-3-1b-it"
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model = Gemma3ForCausalLM.from_pretrained(
    model_id, quantization_config=quantization_config, token='#your token'
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token='#your token')

To Reproduce:

old_input_embedding  = model.get_input_embeddings().weight.detach().clone()
old_output_embedding = model.get_output_embeddings().weight.detach().clone()
old_input_length     = old_input_embedding.shape[0]

new_tokens = ["<CHARACTER_1>", "<THINKING>", "<SCRATCH_PAD>"]
old_tokenizer_length = len(tokenizer)
tokenizer.add_tokens(new_tokens)

model.resize_token_embeddings(len(tokenizer))

new_input_embedding = model.get_input_embeddings().weight.detach()
new_output_embedding = model.get_output_embeddings().weight.detach()

num_added = new_input_embedding.shape[0] - old_input_length
if num_added > 0:
    new_rows = new_input_embedding[-num_added:]  # New token embeddings
    new_token_ids = range(old_input_length, old_input_length + num_added)
    new_tokens_by_embedding = tokenizer.convert_ids_to_tokens(list(new_token_ids))
    
    print(f" Based on embeddings, {num_added} new token(s) were added:")
    for token_id, token in zip(new_token_ids, new_tokens_by_embedding):
        print(f" - Token ID {token_id}: '{token}'")
else:
    print("No new embeddings were added (embedding size unchanged)")

Output:


 Based on embeddings, 4 new token(s) were added:
 - Token ID 262144: '<image_soft_token>'
 - Token ID 262145: '<CHARACTER_1>'
 - Token ID 262146: '<THINKING>'
 - Token ID 262147: '<SCRATCH_PAD>'

Expected behavior

Based on embeddings, 3 new token(s) were added:

 - Token ID 262145: '<CHARACTER_1>'
 - Token ID 262146: '<THINKING>'
 - Token ID 262147: '<SCRATCH_PAD>'
Serzhanov added the bug label Mar 26, 2025
@Rocketknight1
Member

cc @ArthurZucker @itazap for tokenizers

@devdevgoat

This is happening when using mlx_lm.lora as well, even without adding custom tokens. The resulting fine-tune outputs a vocab size of 262145, as opposed to the 262144 specified in config.json. When attempting to use the resulting adapter, Ollama fails validation.

@itazap
Collaborator

itazap commented Apr 3, 2025

Hello @Serzhanov, I took a deeper look and <image_soft_token> is present in the tokenizer even before adding any tokens. You can verify with:

from transformers import AutoTokenizer, BitsAndBytesConfig, Gemma3ForCausalLM
import torch

model_id = "google/gemma-3-1b-it"
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model = Gemma3ForCausalLM.from_pretrained(
    model_id, quantization_config=quantization_config, token='#your token'
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token='#your token')
print(len(tokenizer))
print(model.vocab_size)
print(f"'<image_soft_token>' in tokenizer vocab: {'<image_soft_token>' in tokenizer.vocab}")

The vocab size growing by 4 instead of 3 is because len(tokenizer) = 262145 while model.vocab_size == 262144, so when you resize to len(tokenizer) after adding your 3 tokens, the embedding goes from 262144 to 262148 rows (instead of to 262147), and the first "new" row corresponds to <image_soft_token>.

It's important to note that the model.vocab_size is not always reflective of the actual vocabulary size. For example, sometimes the embeddings are padded to be a multiple of 32/64/128 to optimize performance with tensor accelerators. It's best to rely on len(tokenizer) for the size! 👍
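As a minimal sketch (assuming the model and tokenizer from the snippet above are already loaded), you can count newly added tokens from the tokenizer itself rather than from the embedding matrix, which sidesteps the off-by-one; the optional pad_to_multiple_of argument of resize_token_embeddings only pads the embedding matrix and does not change len(tokenizer):

new_tokens = ["<CHARACTER_1>", "<THINKING>", "<SCRATCH_PAD>"]

# add_tokens returns the number of tokens that were actually added
num_added = tokenizer.add_tokens(new_tokens)

# Resize to the tokenizer size; pad_to_multiple_of keeps the matrix a multiple
# of 64 for hardware efficiency without affecting len(tokenizer).
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)

print(f"{num_added} new token(s) added:")
for token in new_tokens:
    print(f" - Token ID {tokenizer.convert_tokens_to_ids(token)}: '{token}'")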

@itazap
Collaborator

itazap commented Apr 3, 2025

@devdevgoat I'm not familiar with the validation issue you're describing; do you have a code snippet to reproduce the failure?

@Serzhanov
Author

@itazap Hello, thank you for the clarification. I see your point — that makes sense. I’m wondering though, do you think this behavior should raise a warning? In most models and previous versions of Gemma, like gemma-2b, the vocab_size and len(tokenizer) are the same, so this discrepancy might catch some people off guard.

@itazap
Collaborator

itazap commented Apr 4, 2025

This is true for a lot of models (Bloom, Phi, Gemma, etc.), and so I agree with you that it's important to have a good explanation as to how a model's vocab_size is calculated and modified for performance reasons - I will look into adding this info to the Docs and DocStrings!
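For illustration only (this is not the library's actual code), padding a vocabulary size up to a hardware-friendly multiple looks roughly like this:

# Illustrative sketch: round a vocab size up to a multiple of 64.
def pad_vocab(vocab_size: int, multiple: int = 64) -> int:
    return ((vocab_size + multiple - 1) // multiple) * multiple

print(pad_vocab(262145))  # 262208 -- embedding rows can exceed len(tokenizer)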

@Serzhanov
Author

@itazap Great, I can close the issue now.
