
Why Mistral and Devstral models drop their spaces on Apple Silicon

If you’ve run a recent Mistral model — Devstral-Small-2, or anything on the tekken v13 tokenizer — through mlx-lm’s server on a Mac, you may have seen output like this:

Hello!ĠHowĠcanĠIĠassistĠyouĠtoday?

Every space replaced by a literal Ġ. The model is fine. The weights are fine. The bug is in detokenization — the step that turns token IDs back into text — and it’s specific to how this family of tokenizers is built. Here’s the full trace, because the root cause is a genuinely interesting collision between two tokenizer conventions, and the fix is one routing decision.

I tracked this down and shipped the fix to ml-explore/mlx-lm (PR #1329).

Figure: Before and after the fix. The byte-level space marker Ġ leaks into Mistral/Devstral output on Apple Silicon, then is correctly decoded back to spaces.

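TL;DR

The tekken v13 tokenizer pairs a SentencePiece-style decoder with a byte-level (GPT-2 style) vocabulary. mlx-lm picked its streaming detokenizer from the decoder block alone, so it chose SPMStreamingDetokenizer, which only knows the SentencePiece space marker ▁ and passes every Ġ straight through. The fix routes any tokenizer with a ByteLevel pre-tokenizer to BPEStreamingDetokenizer.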

Reproducing it without downloading 24B weights

The first useful move when debugging on-device model behavior: you almost never need the weights. A detokenization bug lives entirely in the tokenizer, which is a few megabytes. So you can reproduce a bug from a 24-billion-parameter model on a laptop in seconds:

from pathlib import Path

from huggingface_hub import snapshot_download
from mlx_lm.utils import load_tokenizer

# tokenizer files only, no model weights
path = snapshot_download(
    "mlx-community/Devstral-Small-2-24B-Instruct-2512-bf16",
    allow_patterns=["*.json", "*.model", "*tokenizer*", "*.jinja"],
)
tok = load_tokenizer(Path(path))   # snapshot_download returns a str; the loader joins file names onto a Path
print(type(tok.detokenizer).__name__)   # SPMStreamingDetokenizer  <- the smoking gun

text = "Hello! How can I assist you today?"
ids = tok.encode(text, add_special_tokens=False)

d = tok.detokenizer
d.reset()
for t in ids:
    d.add_token(t)
d.finalize()
print(repr(d.text))   # 'Hello!ĠHowĠcanĠIĠassistĠyouĠtoday?'

Two things jump out: the detokenizer chosen is SPMStreamingDetokenizer, and the output keeps the Ġ.

The root cause: a hybrid tokenizer

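mlx-lm chooses a streaming detokenizer in tokenizer_utils.load() by inspecting the decoder block of tokenizer.json. Three shapes are recognised: a ByteLevel decoder (routed to BPEStreamingDetokenizer), the SentencePiece sequence Replace ▁ with space, ByteFallback, Fuse, Strip (routed to SPMStreamingDetokenizer), and the same sequence without the final Strip (SPMStreamingDetokenizer with trim_space=False). Anything else falls back to NaiveStreamingDetokenizer.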

Now look at what the tekken v13 tokenizer actually declares:

// decoder  — looks exactly like SentencePiece
{"type": "Sequence", "decoders": [
  {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
  {"type": "ByteFallback"}, {"type": "Fuse"},
  {"type": "Strip", "content": " ", "start": 1, "stop": 0}
]}

// pre_tokenizer — but the vocabulary is byte-level!
{"type": "Sequence", "pretokenizers": [
  {"type": "ByteLevel", "add_prefix_space": false, ...},
  {"type": "Metaspace", "replacement": "▁", ...}
]}

This tokenizer is a chimera. Its decoder matches the SentencePiece pattern exactly, so mlx-lm routes it to the SPM detokenizer. But its pre-tokenizer is ByteLevel, which means the vocabulary is byte-level: tokens are spelled with GPT-2’s printable stand-ins for raw bytes, where Ġ (U+0120) is the space byte 0x20. The tokens literally contain Ġ, ĠĠ, Ġt, and so on.
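A minimal sketch of decoder-only routing makes the misroute concrete. It assumes exact structural matching against the decoder block quoted above; `route_by_decoder` and the matcher names are hypothetical simplifications, not mlx-lm’s source, and the no-Strip SPM variant is left out for brevity.

```python
# Hypothetical sketch: choosing a detokenizer from the decoder block alone.

# SentencePiece-style decoder, as quoted from the tekken v13 tokenizer.json.
SPM_DECODER = {
    "type": "Sequence",
    "decoders": [
        {"type": "Replace", "pattern": {"String": "\u2581"}, "content": " "},
        {"type": "ByteFallback"},
        {"type": "Fuse"},
        {"type": "Strip", "content": " ", "start": 1, "stop": 0},
    ],
}


def is_spm_decoder(decoder):
    # Exact structural match against the SentencePiece decoder sequence.
    return decoder == SPM_DECODER


def is_bpe_decoder(decoder):
    # GPT-2 style tokenizers declare a ByteLevel decoder.
    return isinstance(decoder, dict) and decoder.get("type") == "ByteLevel"


def route_by_decoder(tokenizer_content):
    """Return the detokenizer class name picked from the decoder block only."""
    decoder = tokenizer_content.get("decoder")
    if is_spm_decoder(decoder):
        return "SPMStreamingDetokenizer"
    if is_bpe_decoder(decoder):
        return "BPEStreamingDetokenizer"
    return "NaiveStreamingDetokenizer"


# The tekken v13 shape: SPM-looking decoder, byte-level pre-tokenizer.
tekken_like = {
    "decoder": SPM_DECODER,
    "pre_tokenizer": {
        "type": "Sequence",
        "pretokenizers": [
            {"type": "ByteLevel", "add_prefix_space": False},
            {"type": "Metaspace", "replacement": "\u2581"},
        ],
    },
}

print(route_by_decoder(tekken_like))  # SPMStreamingDetokenizer
```

The pre-tokenizer never gets a vote here, which is exactly why the byte-level vocabulary ends up in front of the wrong detokenizer.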

The SPM detokenizer’s whole job is to replace the SentencePiece space marker ▁ (U+2581) with a space. It has never heard of Ġ. So Ġ sails straight through, untouched, into your output.
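A toy version of the SentencePiece-style decoding step shows the pass-through. `spm_decode` is a hypothetical simplification (it skips byte-fallback handling), and the token strings are an illustrative split of the example sentence, not tekken’s actual tokenization.

```python
from typing import List

SPM_SPACE = "\u2581"  # SentencePiece's word-boundary marker


def spm_decode(tokens: List[str], trim_space: bool = True) -> str:
    # Join the token strings and turn the SentencePiece marker back into spaces.
    text = "".join(tokens).replace(SPM_SPACE, " ")
    # Mirror the decoder's Strip step: drop one leading space.
    if trim_space and text.startswith(" "):
        text = text[1:]
    return text


# A genuine SentencePiece vocabulary marks word starts with "▁".
spm_tokens = ["\u2581Hello", "!", "\u2581How", "\u2581can"]
# Illustrative byte-level tokens: spaces are spelled "Ġ" (U+0120) instead.
byte_level_tokens = ["Hello", "!", "ĠHow", "Ġcan", "ĠI", "Ġassist", "Ġyou", "Ġtoday", "?"]

print(spm_decode(spm_tokens))         # Hello! How can
print(spm_decode(byte_level_tokens))  # Hello!ĠHowĠcanĠIĠassistĠyouĠtoday?
```

Nothing in the byte-level tokens contains ▁, so the replacement is a no-op and Ġ survives verbatim.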

The twist: even Hugging Face gets this one wrong

While verifying, I checked what transformers’ own tokenizer.decode() does with these IDs. It also mangles them — emitting Ġ for English and raw byte mojibake (ãģĵãĤĵ…) for Japanese. The interesting part: mlx-lm’s byte-level detokenizer decodes this tokenizer correctly, where the reference transformers decode does not. The byte-level path knows how to map Ġ → space and reassemble multi-byte UTF-8. So the fix isn’t “match what HF does” — it’s “use the detokenizer that’s actually right for a byte-level vocabulary.”
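The byte-level path works because GPT-2’s byte-to-unicode table is invertible: each character in the vocabulary stands for exactly one raw byte. Below is a minimal sketch of a streaming byte-level detokenizer, assuming the standard GPT-2 table. `ToyByteLevelDetokenizer` is hypothetical: it only mirrors the `reset` / `add_token` / `finalize` / `text` shape used in the repro above, takes token strings rather than IDs to stay self-contained, and is not mlx-lm’s `BPEStreamingDetokenizer`. The incremental UTF-8 decoder holds back a partial multi-byte character until its remaining bytes arrive, which is how the Japanese mojibake reassembles.

```python
import codecs
from typing import Dict


def gpt2_bytes_to_unicode() -> Dict[int, str]:
    # Printable bytes keep their own character; the rest (including the
    # space 0x20) are shifted to code points starting at U+0100.
    printable = (
        list(range(0x21, 0x7E + 1))
        + list(range(0xA1, 0xAC + 1))
        + list(range(0xAE, 0xFF + 1))
    )
    byte_values = printable[:]
    code_points = printable[:]
    shift = 0
    for b in range(256):
        if b not in byte_values:
            byte_values.append(b)
            code_points.append(256 + shift)
            shift += 1
    return {b: chr(c) for b, c in zip(byte_values, code_points)}


# Inverse table: vocabulary character -> raw byte.
BYTE_DECODER = {ch: b for b, ch in gpt2_bytes_to_unicode().items()}


class ToyByteLevelDetokenizer:
    """Hypothetical streaming detokenizer for a GPT-2 style byte-level vocabulary."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.text = ""
        # Buffers incomplete UTF-8 sequences that span token boundaries.
        self._utf8 = codecs.getincrementaldecoder("utf-8")(errors="replace")

    def add_token(self, token: str) -> None:
        # Map each vocabulary character back to its byte, then decode what is complete.
        raw = bytes(BYTE_DECODER[ch] for ch in token)
        self.text += self._utf8.decode(raw)

    def finalize(self) -> None:
        # Flush anything still buffered.
        self.text += self._utf8.decode(b"", final=True)


d = ToyByteLevelDetokenizer()
for t in ["Hello", "!", "ĠHow", "Ġcan", "ĠI", "Ġassist", "Ġyou", "Ġtoday", "?"]:
    d.add_token(t)
d.finalize()
print(repr(d.text))  # 'Hello! How can I assist you today?'

d.reset()
for t in ["ãģ", "ĵ", "ãĤĵ"]:  # こ split across two tokens, then ん
    d.add_token(t)
d.finalize()
print(d.text)  # こん
```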

The fix

The decoder field lied; the pre-tokenizer told the truth. A ByteLevel pre-tokenizer is an authoritative signal that the vocabulary is byte-level and therefore needs the BPE detokenizer — no matter what the decoder looks like:

def _has_byte_level_pretokenizer(tokenizer_content):
    def _is_byte_level(node):
        return isinstance(node, dict) and node.get("type") == "ByteLevel"
    pre = tokenizer_content.get("pre_tokenizer")
    if _is_byte_level(pre):
        return True
    if isinstance(pre, dict) and pre.get("type") == "Sequence":
        return any(_is_byte_level(p) for p in pre.get("pretokenizers", []))
    return False

And in the routing:

if _is_bpe_decoder(decoder) or _has_byte_level_pretokenizer(tokenizer_content):
    detokenizer_class = BPEStreamingDetokenizer      # byte-level vocab wins
elif _is_spm_decoder(decoder):
    detokenizer_class = SPMStreamingDetokenizer
elif _is_spm_decoder_no_space(decoder):
    detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False)

Result:

before: Hello!ĠHowĠcanĠIĠassistĠyouĠtoday?   (SPMStreamingDetokenizer)
after:  Hello! How can I assist you today?    (BPEStreamingDetokenizer)

Making sure it doesn’t break the models that already worked

The one real risk in a routing change like this is regressing tokenizers that were fine. A genuine SentencePiece tokenizer must keep using the SPM detokenizer. The key insight that makes the fix safe: real SPM tokenizers don’t have a ByteLevel pre-tokenizer. I confirmed that against the detokenizer-class assertions already in the test suite, downloading tokenizer files only.

Only the broken case changes. The PR adds a regression test for the tekken case (asserting Ġ-free output) plus a unit test for the new helper.

Why this matters if you’re shipping LLMs on-device

This is the kind of bug that doesn’t show up in a benchmark and doesn’t show up in the cloud — it shows up the moment a specific model meets a specific runtime on someone’s laptop. On-device and edge deployment is full of these seams: tokenizer conventions, quantization formats, memory limits, sliding-window caches. The model “works” everywhere except the one stack your user is actually running.

Debugging them well is mostly discipline: reproduce with the smallest possible artifact (here, tokenizer-only, no weights), trust the data over the labels (the decoder said SPM; the vocabulary said byte-level), and verify you didn’t break the cases that worked. The fix was three lines. Finding the right three lines was the job.


I’m Prasad Khake — I build AI-native products and make LLMs run well on real, on-device, and Apple-Silicon hardware. If you’re shipping local or private LLMs and hitting exactly these kinds of seams, get in touch. The fix above is PR #1329 in ml-explore/mlx-lm.
