Matryoshka Representation Learning for Medical Coding

ICD coding is one of those problems that sounds administrative until you look at it closely. Every hospital encounter ends with a coder translating the clinical narrative into a set of ICD codes — standardised identifiers for diagnoses and procedures. The codes drive billing, epidemiological statistics, and research datasets. Miscoding costs money and distorts health data at scale.

It’s also a retrieval problem disguised as a classification problem. There are tens of thousands of ICD codes. Training a classifier with a head over all of them is brittle and doesn’t generalise to new or rare codes. The better framing: learn a shared embedding space where clinical text and code descriptions are geometrically close when semantically related, then do nearest-neighbour retrieval at inference time.

That’s a contrastive bi-encoder. The question is what kind of embeddings to learn.
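To make the retrieval framing concrete, here is a minimal sketch in plain Python. The three-dimensional vectors and the `rank_codes` helper are made up for illustration. In a real system both sides would come from the trained encoder and the vectors would be far larger.

```python
import math

def normalize(v):
    """Scale a vector to unit length so a dot product equals cosine similarity."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def rank_codes(query_vec, code_vecs, top_k=2):
    """Return the top_k code ids by cosine similarity to the query."""
    q = normalize(query_vec)
    scores = {
        code: sum(a * b for a, b in zip(q, normalize(vec)))
        for code, vec in code_vecs.items()
    }
    return sorted(scores, key=scores.get, reverse=True)[:top_k]

# Toy embeddings (hypothetical values, not model output).
code_vecs = {
    "J18.1": [0.9, 0.1, 0.0],   # lobar pneumonia, unspecified organism
    "J90":   [0.7, 0.6, 0.1],   # pleural effusion
    "I10":   [0.0, 0.2, 0.9],   # essential hypertension
}
note_vec = [0.8, 0.3, 0.05]     # stand-in for an embedded clinical note

top = rank_codes(note_vec, code_vecs)
```

Supporting a new code here means adding one more vector to `code_vecs`, with no classifier head to resize or retrain. That is the practical appeal of the retrieval framing.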

The fixed-dimension trap

Standard dense retrieval learns embeddings at a single fixed dimension, say 768d for a BERT-class model. Every query and every document lives in that 768-dimensional space. At inference you compute dot products and rank.

The problem for production medical systems: you don’t always have the same latency budget. A fast first-pass filter over 50,000 ICD codes needs to be cheap. A precise re-ranker over 20 candidates can afford to be expensive. With fixed-dimension embeddings you’re stuck at one operating point — you can’t trade off speed against accuracy without training a separate model.
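Some back-of-the-envelope arithmetic shows the size of the gap. This sketch only counts multiply-adds for a brute-force dot-product scan. It ignores memory bandwidth, batching, and approximate-nearest-neighbour indexes, so treat the numbers as a rough guide rather than a latency measurement.

```python
NUM_CODES = 50_000  # catalogue size used as the running example in this post

def multiply_adds(num_vectors, dim):
    """Multiply-add operations for one brute-force dot-product scan."""
    return num_vectors * dim

full_scan = multiply_adds(NUM_CODES, 768)   # every code at full dimension
cheap_scan = multiply_adds(NUM_CODES, 64)   # every code at a 64d prefix
rerank = multiply_adds(20, 768)             # 20 candidates at full dimension
speedup = full_scan / cheap_scan            # 768 / 64 = 12x fewer operations
```

Under these assumptions the full-dimension scan costs 38.4 million multiply-adds per query, against 3.2 million at 64d, while re-ranking 20 candidates at 768d costs about 15 thousand.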

The naive fix is to train multiple models at different dimensions. That’s expensive, operationally messy, and wasteful: a 64d model and a 768d model trained separately share no parameters even though they solve the same task at different resolutions.

What Matryoshka embeddings actually are

Matryoshka Representation Learning (Kusupati et al., 2022) trains a single model to produce embeddings that are simultaneously meaningful at multiple nested prefix dimensions — in my case [64, 128, 256, 768].

The key constraint: the first 64 dimensions of a 768d embedding must themselves form a useful 64d representation. The first 128 must form a useful 128d representation. And so on. Like the nested dolls — each prefix is a complete, usable representation at its own scale.
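In code, taking a prefix is just a slice. The sketch below uses an 8-dimensional stand-in for the 768d vector and a hypothetical `truncate` helper. Re-normalising after the slice is my assumption here: it keeps cosine similarity well defined for the shorter vector.

```python
import math

def truncate(embedding, dim):
    """Keep the first `dim` coordinates and rescale them to unit length."""
    prefix = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in prefix))
    return [x / norm for x in prefix]

# Stand-in for a full embedding (8 dims instead of 768, made-up values).
full = [0.5, -0.5, 0.5, 0.5, 0.1, -0.1, 0.05, 0.02]

coarse = truncate(full, 4)   # plays the role of the 64d prefix
fine = truncate(full, 8)     # plays the role of the full 768d vector
```

Both `coarse` and `fine` come from the same forward pass. The only difference is how many leading coordinates you keep.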

The loss function makes this happen by computing the contrastive objective at every nesting dimension in a single forward pass and summing the weighted losses:

$$\mathcal{L}_{MRL} = \sum_{m \in M} w_m \cdot \mathcal{L}_{contrastive}\big(f(x)_{1:m},\, f(y)_{1:m}\big)$$

where $M = \{64, 128, 256, 768\}$ and $f(x)_{1:m}$ is the first $m$ dimensions of the embedding.

This puts real pressure on the encoder. The constraint is soft, enforced through the loss rather than the architecture, but to score well at every prefix the model must pack the most semantically discriminative information into the early dimensions. Coarser structure first, finer detail later. At inference you truncate to any prefix and get a usable representation: no retraining, no separate models.
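A minimal sketch of that loss in dependency-free Python, to show its shape rather than reproduce the training code. Several things here are assumptions: InfoNCE with in-batch negatives as the contrastive objective, uniform weights, a temperature of 0.1, and tiny nested dimensions `[2, 4]` standing in for `[64, 128, 256, 768]`. Real training would use an autodiff framework.

```python
import math

NESTED_DIMS = [2, 4]          # stand-ins for [64, 128, 256, 768]
WEIGHTS = {2: 1.0, 4: 1.0}    # w_m; uniform weighting is an assumption

def unit_prefix(vec, dim):
    """First `dim` coordinates, rescaled to unit length."""
    prefix = vec[:dim]
    norm = math.sqrt(sum(x * x for x in prefix)) or 1.0
    return [x / norm for x in prefix]

def info_nce(queries, docs, temperature=0.1):
    """Mean cross-entropy where docs[i] is the positive for queries[i]
    and every other doc in the batch acts as a negative."""
    total = 0.0
    for i, q in enumerate(queries):
        logits = [sum(a * b for a, b in zip(q, d)) / temperature for d in docs]
        peak = max(logits)  # subtract the max for numerical stability
        log_sum = peak + math.log(sum(math.exp(l - peak) for l in logits))
        total += log_sum - logits[i]
    return total / len(queries)

def mrl_loss(text_embs, code_embs):
    """Weighted sum of the contrastive loss over every nested prefix."""
    loss = 0.0
    for m in NESTED_DIMS:
        q = [unit_prefix(v, m) for v in text_embs]
        d = [unit_prefix(v, m) for v in code_embs]
        loss += WEIGHTS[m] * info_nce(q, d)
    return loss

# Toy batch: row i of each list is a matching (text, code) pair.
text_embs = [[1.0, 0.0, 0.2, 0.1], [0.0, 1.0, 0.1, 0.3]]
code_embs = [[0.9, 0.1, 0.3, 0.0], [0.1, 0.9, 0.0, 0.4]]

aligned = mrl_loss(text_embs, code_embs)         # correct pairing
shuffled = mrl_loss(text_embs, code_embs[::-1])  # deliberately mismatched
```

The correctly paired batch gives a much lower loss than the shuffled one. Because the loss is summed over prefixes, a model that only separates the pairs in its later dimensions still pays a penalty at the short prefixes.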

Why this fits ICD coding specifically

The ICD taxonomy has natural granularity levels. At the top: broad disease categories (diseases of the respiratory system). At the bottom: specific codes with clinical nuance (J18.1 — lobar pneumonia, unspecified organism). A 64d embedding can distinguish respiratory from cardiovascular; a 768d embedding distinguishes pneumonia subtypes.

This maps cleanly onto a retrieval pipeline:

  1. Fast 64d pass → filter from 50K codes to 200 candidates
  2. 256d re-rank → narrow to 20
  3. 768d final → rank the shortlist

One model. Three operating points. No retraining between them.
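The three stages can be sketched as one loop over `(dimension, keep)` pairs. This is a toy illustration, not the production pipeline: the vectors are random stand-ins, the catalogue has 1,000 entries instead of 50K, and `cascade` and `prefix_score` are hypothetical helper names.

```python
import math
import random

def prefix_score(a, b, dim):
    """Cosine similarity using only the first `dim` coordinates."""
    dot = sum(x * y for x, y in zip(a[:dim], b[:dim]))
    na = math.sqrt(sum(x * x for x in a[:dim]))
    nb = math.sqrt(sum(x * x for x in b[:dim]))
    return dot / (na * nb)

def cascade(query, codes, stages):
    """Run each (dim, keep) stage on the survivors of the previous stage."""
    candidates = list(codes)
    for dim, keep in stages:
        candidates.sort(
            key=lambda c: prefix_score(query, codes[c], dim), reverse=True
        )
        candidates = candidates[:keep]
    return candidates

# Random stand-in vectors; a real index would hold encoder outputs.
rng = random.Random(0)
codes = {f"code_{i}": [rng.gauss(0, 1) for _ in range(768)] for i in range(1000)}

# The query is a noisy copy of one code, so we know what should come out on top.
query = [x + 0.1 * rng.gauss(0, 1) for x in codes["code_42"]]

shortlist = cascade(query, codes, stages=[(64, 200), (256, 20), (768, 20)])
```

Only the first stage touches the whole catalogue, and it does so at 64 dimensions. The expensive 768d comparison runs on 20 candidates.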

The encoder choice

Standard BERT has a 512-token context limit. Radiology reports can run long, and truncating them loses findings. BioClinical-ModernBERT brings the ModernBERT architecture (Flash Attention 2, an 8192-token context, and stronger results than BERT on standard benchmarks) to a model pretrained on clinical text (MIMIC-III). Longer context without losing domain grounding.

Label-Aware Attention

Mean-pooling collapses the full sequence into one vector. For multi-label ICD coding, different sentences in the report are evidence for different codes. A finding about atelectasis is in one sentence; a finding about pleural effusion is in another.

Label-Aware Attention adds a per-label attention mechanism that learns to focus on the tokens most relevant to each candidate code. Instead of pooling everything, it extracts targeted evidence. In practice this helped most on reports with multiple findings — exactly the hard cases.
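One common way to write per-label attention is to give each label its own query vector, score every token against it, and take a softmax-weighted sum of the token states. The sketch below follows that formulation, which is an assumption on my part rather than the exact layer used in this project. The token states and label queries are hand-picked 2-d values so the behaviour is easy to see; in practice both are learned.

```python
import math

def softmax(scores):
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def label_attention(token_states, label_queries):
    """For each label, attend over tokens; return (weights, pooled vector)."""
    out = {}
    for label, query in label_queries.items():
        scores = [sum(h * q for h, q in zip(state, query)) for state in token_states]
        weights = softmax(scores)
        pooled = [
            sum(w * state[j] for w, state in zip(weights, token_states))
            for j in range(len(token_states[0]))
        ]
        out[label] = (weights, pooled)
    return out

# Hypothetical token states for a four-token report fragment.
tokens = ["atelectasis", "and", "pleural", "effusion"]
token_states = [[3.0, 0.0], [0.1, 0.1], [0.0, 2.5], [0.0, 3.0]]
label_queries = {
    "atelectasis": [1.0, 0.0],        # learned per-label query in practice
    "pleural_effusion": [0.0, 1.0],
}

attended = label_attention(token_states, label_queries)

# Token receiving the most attention for each label.
focus = {
    label: tokens[max(range(len(tokens)), key=weights.__getitem__)]
    for label, (weights, _) in attended.items()
}
```

Each label ends up with its own pooled vector built mostly from the tokens that matter to it. Mean-pooling, by contrast, gives every label the same blended summary.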

What I learned

The most useful insight from this project wasn’t technical — it was about problem framing. The fixed-dimension assumption is so embedded in how dense retrieval is typically discussed that it’s easy to miss it as a choice rather than a constraint. MRL makes it a choice.

The geometric intuition also matters: you’re not just adding auxiliary losses, you’re enforcing a hierarchical structure on the embedding manifold. The model has to learn that coarser distinctions are more fundamental than finer ones — which is true in medicine, in language, and in most structured domains.

That’s a useful prior to build into a model, not just a training trick.
