
Fine-Tuning PaliGemma 3B, 10B, and 28B for Bounding-Box Generation

A practitioner's guide to fine-tuning Google's PaliGemma vision-language model for medical bounding-box detection — QLoRA across three model sizes, the SigLIP unfreeze question, loss-function tradeoffs, and the configs that actually worked.

Saianiruth M

Almost no good public writeup exists on fine-tuning PaliGemma for bounding-box generation. This is what we learned across the 3B, 10B, and 28B model sizes on a medical detection task — what worked, what didn't, and where the sweet spot lives.


PaliGemma is Google's vision-language model family that pairs a SigLIP visual encoder with a Gemma language decoder. Its party trick, relative to other VLMs, is that it generates bounding boxes as text tokens — special <locXXXX> codes that map to normalized image coordinates. You fine-tune the language model to emit four of these tokens per detection, parse them out at inference, and you have a detector. In principle.

In practice, almost everything between "load the model from Hugging Face" and "get usable boxes" turns out to require choices that aren't documented well anywhere. We spent a stretch of this past year fine-tuning all three model sizes — 3B, 10B, 28B — on a medical bounding-box task (shoulder-fracture detection on radiographs). This post is the practitioner's-guide writeup of what we converged on.


The bounding-box format

Before anything else, the format. PaliGemma represents a bounding box as four special tokens in the text output:

<loc0482><loc0103><loc0599><loc0226> Fracture;

The four tokens are normalized coordinates in (y_min, x_min, y_max, x_max) order, each in the range 0–1023. So <loc0482> means y = 482 / 1024 = 0.471 of the image height. To convert to pixel coordinates you scale by image height/width.

A few non-obvious things about this format:

  • The order is y_min, x_min, y_max, x_max — not the COCO-standard x_min, y_min, w, h. This trips up everyone the first time. If you parse with the wrong order, all your IoUs are zero and you spend a day wondering why training isn't converging.
  • The tokens are quantized normalized coordinates, not absolute pixel positions and not learned embeddings of continuous coordinates. The model is doing a classification over 1,024 discrete bins per coordinate. Boundary precision is therefore ~1/1024 of the image dimension, which is fine for natural images and borderline for small medical pathologies.
  • The class label follows the boxes as a regular text token: Fracture or Normal. The model is generating an autoregressive sequence: four location tokens, then the class name, then end-of-sequence. The decoder is doing classification and localization in one shot.
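A minimal parsing sketch for this format, assuming the output looks like the example above (four `<locNNNN>` tokens, then a label, optionally terminated by `;`). The function name `parse_detections` is ours for illustration, not part of any library:

```python
import re

# One detection = four <locNNNN> tokens followed by a label (up to ';', '<', or end of string).
_DETECTION = re.compile(r"((?:<loc\d{4}>){4})\s*([^;<]+)")


def parse_detections(text, image_height, image_width):
    """Decode PaliGemma-style output into pixel-space boxes.

    Returns a list of (label, (y_min, x_min, y_max, x_max)) with pixel coordinates.
    """
    detections = []
    for locs, label in _DETECTION.findall(text):
        # Token order is y_min, x_min, y_max, x_max.
        y0, x0, y1, x1 = (int(v) for v in re.findall(r"<loc(\d{4})>", locs))
        box = (
            y0 / 1024 * image_height,
            x0 / 1024 * image_width,
            y1 / 1024 * image_height,
            x1 / 1024 * image_width,
        )
        detections.append((label.strip(), box))
    return detections
```

An output with no `<loc>` tokens (for example `Normal`) yields an empty list, so the caller can treat "no detections" and "negative study" the same way.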

The training data, in JSONL, looks like:

{"image": "study1.jpeg", "prefix": "Detect Fracture",
 "suffix": "<loc0482><loc0103><loc0599><loc0226> Fracture;"}
{"image": "study2.jpeg", "prefix": "Detect Fracture",
 "suffix": "Normal"}

The prefix is the user prompt; the suffix is what the model learns to emit. Empty-box cases (no fracture) just produce the class label Normal with no <loc> tokens. Multiple boxes are space-separated with the same pattern repeating.
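Going the other way, here is a sketch of how a pixel-space annotation could be turned into one of these JSONL records. The helper names, the floor-then-clamp quantization, and the default label and prefix are assumptions for illustration; check them against whatever convention your data pipeline already uses:

```python
import json


def to_loc_token(coord_px, size_px):
    """Quantize one pixel coordinate into a <locNNNN> token (bins 0..1023)."""
    bin_index = int(coord_px / size_px * 1024)
    bin_index = max(0, min(1023, bin_index))  # clamp so coord == size maps to bin 1023
    return f"<loc{bin_index:04d}>"


def make_record(image_name, boxes_px, image_height, image_width,
                label="Fracture", prefix="Detect Fracture"):
    """Build one JSONL line; boxes_px holds (y_min, x_min, y_max, x_max) in pixels."""
    if not boxes_px:
        suffix = "Normal"  # empty-box case: class label only, no <loc> tokens
    else:
        parts = []
        for y0, x0, y1, x1 in boxes_px:
            locs = (to_loc_token(y0, image_height) + to_loc_token(x0, image_width)
                    + to_loc_token(y1, image_height) + to_loc_token(x1, image_width))
            parts.append(f"{locs} {label};")
        suffix = " ".join(parts)  # multiple boxes: same pattern, space-separated
    return json.dumps({"image": image_name, "prefix": prefix, "suffix": suffix})
```

Note that y coordinates are scaled by image height and x coordinates by image width, which matters for non-square radiographs.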


Architecture refresher

Figure: PaliGemma dataflow for bounding-box generation. A 224×224 input image feeds a 27-layer SigLIP encoder, an optional custom task-attention head (the dashed box in the diagram: a 256-query cross-attention bottleneck we inserted between the encoder and projector), the multimodal projector, and the Gemma language decoder, which emits four loc tokens followed by a class label.

The dataflow has four stages:

  1. Image → SigLIP encoder. 224×224 input, divided into 14×14-pixel patches (so a 16×16 grid = 256 patches), passed through a 27-layer Vision Transformer that emits a sequence of 1152-dim patch embeddings.
  2. Patch embeddings → multimodal projector. A single linear layer projects from the SigLIP embedding dimension (1152) into the language model's embedding space (2304 for the Gemma 2 2B decoder inside PaliGemma 2 3B; the original PaliGemma 3B used 2048).
  3. Projected image tokens + text prompt → Gemma decoder. The image tokens are prepended to the prompt tokens, and the whole sequence is fed to a decoder-only Gemma language model that processes it autoregressively.
  4. Output tokens. The decoder emits the next tokens — for our use case, four <locXXXX> tokens followed by a class label.

Default fine-tuning recipe: only the language model is unfrozen. The SigLIP encoder and the projector both stay frozen. This is what Google's training scripts assume out of the box.

For medical imaging — where the SigLIP encoder hasn't seen X-rays during pre-training and the relevant visual features are subtle — this default is suboptimal. We'll come back to it.


QLoRA setup across model sizes

The three model sizes have very different memory profiles:

Model                                  Total params  Tuning strategy                   Quantization               Effective trainable
PaliGemma 3B (paligemma2-3b-pt-224)    ~3B           Full fine-tune of LM + projector  bfloat16                   ~3B
PaliGemma 10B (paligemma2-10b-pt-224)  ~10B          QLoRA on LM                       8-bit                      ~50M (LoRA params)
PaliGemma 28B (paligemma2-28b-pt-224)  ~28B          QLoRA on LM                       4-bit (nf4, double-quant)  ~80M (LoRA params)

QLoRA configuration we used on both 10B and 28B:

from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

A few notes:

  • Rank 16 is a common LoRA/QLoRA choice and worked fine. We didn't sweep this rigorously; the literature suggests larger ranks help for harder tasks but hurt for easier ones. For medical bounding-box generation, 16 is in the right neighborhood.
  • target_modules restricted to attention projections. Including the MLP layers (up_proj, down_proj) increases trainable parameter count significantly but didn't materially improve our metrics. We kept the configuration minimal.
  • 8-bit on 10B, 4-bit on 28B. This isn't a choice so much as a hardware constraint — 28B doesn't fit on a single A100 80GB without 4-bit quantization. bnb_4bit_quant_type="nf4" and bnb_4bit_use_double_quant=True are non-negotiable for the 28B path.
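For reference, the 4-bit path looks roughly like the fragment below. This is a sketch under assumptions rather than our exact script: the Hub id and the bfloat16 compute dtype are assumed, `lora_config` refers to the LoraConfig defined above, and you need `bitsandbytes` installed plus a GPU large enough for the quantized weights.

```python
import torch
from peft import get_peft_model, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumption: matches the bf16 precision used elsewhere
)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma2-28b-pt-224",  # assumed Hub id for the 28B checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)  # lora_config as defined in the block above
model.print_trainable_parameters()  # sanity-check the LoRA parameter count
```

For the 8-bit 10B path, the analogous change would be `BitsAndBytesConfig(load_in_8bit=True)` in place of the 4-bit settings.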

Training hyperparameters we converged on for the 3B (which got full fine-tuning, no QLoRA):

Batch size:           2 (with gradient accumulation × 8 = effective 16)
Learning rate:        2e-5, cosine schedule, 2-step warmup
Optimizer:            AdamW (weight_decay=1e-6, betas=(0.9, 0.999))
Image size:           224 × 224
Precision:            bfloat16
Epochs:               planned 50, early-stopped around epoch 15

The QLoRA variants (10B, 28B) used similar hyperparameters with batch size 1 to fit memory.
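To make the schedule concrete, here is a small standalone sketch of linear warmup followed by cosine decay, plus the effective-batch arithmetic. The exact warmup shape and the decay-to-zero floor are assumptions; library schedulers differ slightly in how they count the first step.

```python
import math


def lr_at_step(step, total_steps, base_lr=2e-5, warmup_steps=2):
    """Linear warmup to base_lr, then cosine decay to zero at total_steps."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * min(1.0, progress)))


# Effective batch size = per-device batch x gradient-accumulation steps.
effective_batch = 2 * 8
```

With a 2-step warmup the schedule is almost pure cosine decay, so the peak learning rate is reached essentially immediately.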


The SigLIP unfreeze question

The single most impactful architectural decision we made was unfreezing part of the SigLIP encoder.

Default PaliGemma fine-tuning keeps SigLIP frozen entirely. The argument is reasonable: SigLIP has already learned strong general visual features; you don't want to overwrite them with a small medical dataset.

The counterargument is also reasonable: SigLIP has not learned radiograph-specific features. The patches that contain fracture-relevant texture look nothing like the natural-image pre-training distribution. Letting the last few SigLIP layers adapt to the task seemed worth a try.

We tested three configurations:

SigLIP setting          3B classification accuracy
Fully frozen            61.3%
Last 5 layers unfrozen  62.6%
Last 9 layers unfrozen  64.5%

Unfreezing the last 5 layers gained 1.3 points over the frozen baseline at roughly half the additional gradient-memory cost of unfreezing 9. Going to 9 layers added a further 1.9 points, but at meaningful additional cost. We settled on unfreezing the last 5 layers as the default, trading some accuracy for memory headroom.

This pattern shows up across the bounding-box results too — runs with the unfrozen-5 setting consistently outperformed fully-frozen-SigLIP runs by a few F1 points.
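One way to select the layers to unfreeze is by parameter name. The sketch below assumes the vision-encoder blocks are named like `...vision_tower.vision_model.encoder.layers.<i>...`, which is how the Hugging Face implementation names them as far as we can tell; print `model.named_parameters()` to confirm for your version. The helper name is hypothetical.

```python
import re

# Assumption: vision-encoder block parameters are named like
# "...vision_tower.vision_model.encoder.layers.<i>.<rest>".
_VISION_LAYER = re.compile(r"vision_tower\..*encoder\.layers\.(\d+)\.")


def vision_params_to_unfreeze(param_names, num_layers=27, last_n=5):
    """Return the names of parameters in the last `last_n` vision-encoder blocks."""
    first_unfrozen = num_layers - last_n
    selected = []
    for name in param_names:
        match = _VISION_LAYER.search(name)
        if match and int(match.group(1)) >= first_unfrozen:
            selected.append(name)
    return selected


# With a loaded bf16 model (not executed here):
#   names = [n for n, _ in model.named_parameters()]
#   unfreeze = set(vision_params_to_unfreeze(names))
#   for n, p in model.named_parameters():
#       if n in unfreeze:
#           p.requires_grad = True
```

Selecting by name keeps the choice of `last_n` a one-line change, which makes the 5-versus-9 comparison above cheap to rerun.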


Loss functions

For the box-regression target, we used a combination of Focal Loss for the class assignment and CIoU Loss for the box-coordinate regression. The intuition:

  • Focal Loss handles the class imbalance between fracture / non-fracture examples (our training set was ~2:1 in favor of fracture cases) by down-weighting easy, already well-classified examples so the gradient concentrates on the hard ones. Standard for any imbalanced detection task.
  • CIoU (Complete IoU) Loss for the box-coordinate prediction. Plain L1/L2 loss on token IDs penalizes a 50-token coordinate error by the same amount regardless of box size: on a large box that is a near-miss, on a small box it can wipe out the overlap entirely. CIoU penalizes by the geometric overlap quality, which is closer to what you actually care about.
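For readers who haven't met CIoU, here is a plain-Python sketch of the standard formula (1 − IoU, plus a normalized centre-distance term, plus an aspect-ratio term) on boxes in the same (y_min, x_min, y_max, x_max) order used above. It illustrates the geometry only; a training loss needs a differentiable tensor version and a way to get continuous coordinates out of the token predictions, neither of which is shown here.

```python
import math


def ciou_loss(pred, target, eps=1e-9):
    """Complete-IoU loss for two boxes given as (y_min, x_min, y_max, x_max)."""
    py0, px0, py1, px1 = pred
    ty0, tx0, ty1, tx1 = target

    # Intersection over union.
    inter_h = max(0.0, min(py1, ty1) - max(py0, ty0))
    inter_w = max(0.0, min(px1, tx1) - max(px0, tx0))
    inter = inter_h * inter_w
    area_p = (py1 - py0) * (px1 - px0)
    area_t = (ty1 - ty0) * (tx1 - tx0)
    iou = inter / (area_p + area_t - inter + eps)

    # Squared distance between box centres, normalized by the squared
    # diagonal of the smallest box enclosing both.
    centre_dist = (((py0 + py1) - (ty0 + ty1)) ** 2 + ((px0 + px1) - (tx0 + tx1)) ** 2) / 4
    enclose_h = max(py1, ty1) - min(py0, ty0)
    enclose_w = max(px1, tx1) - min(px0, tx0)
    diag = enclose_h ** 2 + enclose_w ** 2 + eps

    # Aspect-ratio consistency term.
    v = (4 / math.pi ** 2) * (
        math.atan((tx1 - tx0) / (ty1 - ty0 + eps))
        - math.atan((px1 - px0) / (py1 - py0 + eps))
    ) ** 2
    alpha = v / (1 - iou + v + eps)

    return 1 - iou + centre_dist / diag + alpha * v
```

Unlike plain IoU, the centre-distance term keeps the loss informative even when the boxes don't overlap at all, so a far miss scores worse than a near one.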

We also tested Weighted Cross-Entropy (WCE) as a simpler alternative to Focal Loss. On the 3B model:

Loss config  Precision  Recall  F1
Focal        77.4%      63.5%   69.8%
WCE          72.5%      69.6%   70.0%

Focal Loss pushes precision up at the cost of recall (the model becomes more conservative). WCE keeps precision and recall balanced. For a screening task — where missed fractures matter more than false alarms — WCE is the better choice. For a confirmatory task where false positives matter more, Focal Loss wins. We use WCE in our screening pipeline.
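The difference between the two losses is easiest to see per example. This sketch uses the textbook definitions with `alpha=0.25` and `gamma=2.0`, which are the commonly cited focal-loss defaults and not necessarily the values we trained with:

```python
import math


def weighted_ce(p_true, class_weight):
    """Weighted cross-entropy for one example; p_true is the probability of the true class."""
    return -class_weight * math.log(p_true)


def focal(p_true, alpha=0.25, gamma=2.0):
    """Focal loss for one example: cross-entropy scaled by (1 - p_true) ** gamma."""
    return -alpha * (1.0 - p_true) ** gamma * math.log(p_true)


# Compare an easy, a medium, and a hard example.
for p in (0.95, 0.6, 0.1):
    print(f"p_true={p:.2f}  wce={weighted_ce(p, 1.0):.4f}  focal={focal(p):.4f}")
```

WCE scales every example of a class by the same constant, while the focal term shrinks the loss on confident predictions far more than on uncertain ones, which is why the two land at different precision/recall operating points.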


Custom attention head

One non-default modification that ended up helping: inserting a custom attention layer between the SigLIP encoder and the multimodal projector. The module is conceptually simple — 256 learnable queries that cross-attend over the SigLIP patch embeddings, producing a fixed-size summary that's then projected into the language model.

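import torch
import torch.nn as nn

class TaskAttentionHead(nn.Module):
    """Learnable queries that cross-attend over the SigLIP patch embeddings."""

    def __init__(self, embed_dim=1152, num_queries=256):
        super().__init__()
        # One shared set of learnable queries, broadcast across the batch.
        self.queries = nn.Parameter(torch.randn(1, num_queries, embed_dim))
        self.attn = nn.MultiheadAttention(embed_dim, num_heads=1, batch_first=True)

    def forward(self, patch_embeddings):
        # patch_embeddings: (B, num_patches, embed_dim)
        B = patch_embeddings.size(0)
        queries = self.queries.expand(B, -1, -1)
        # Queries attend over the patches (used as keys and values);
        # the returned attention weights are discarded here.
        out, _ = self.attn(queries, patch_embeddings, patch_embeddings)
        return out  # (B, num_queries, embed_dim)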

The motivation: the SigLIP encoder produces 256 patch embeddings per 224×224 image, all weighted equally before the projector. For a small pathology that takes up maybe 10–20 of those patches, most of the visual context the language model sees is irrelevant. The custom attention head forces an early bottleneck that concentrates the visual representation on what matters.

On the 3B model with the last 5 SigLIP layers unfrozen, adding the attention head took F1 from 0.700 to 0.717 — a modest but real win, with explainability as a side benefit (you can visualize which patches each query attends to).


Preprocessing experiments

Two preprocessing additions we tested:

Canny edge detection (helped). Apply Canny edge detection to the input image, blend 30% edges with 70% original, feed to PaliGemma. The intuition: bone boundaries (where fractures appear) become more visually distinct. Result: F1 improved over the baseline to 0.662, still a few points short of the WCE-loss variant (0.700). A reasonable preprocessing add when you can spare the compute.

CLAHE — Contrast Limited Adaptive Histogram Equalization (broke training). CLAHE is standard preprocessing for chest X-rays in many CNN pipelines. We added it to the PaliGemma input. Result: F1 = 0.000. The model collapsed to predicting Normal for every input. The hypothesis: CLAHE significantly changes the global luminance distribution of the image, and the SigLIP encoder's pre-training expects natural-image luminance statistics. The encoder's representations under CLAHE were apparently far enough out-of-distribution that the language model couldn't learn anything useful from them.

The lesson: preprocessing that works for CNNs trained from scratch can break VLMs whose encoder was pre-trained on natural images. The further you push the input distribution from the encoder's pre-training, the more the encoder's output degrades. We dropped CLAHE entirely.


Results across configurations

Figure: F1 across the eleven experimental configurations for fracture detection (horizontal bar chart). The 10B "all-prefix Detect" run wins at 0.791, but 3B with the right configuration sits within a few points. CLAHE preprocessing collapses training (F1 = 0.000) and is shown for contrast.

The full picture across our experimental configurations:

Model     Config                       Accuracy  Precision  Recall  F1
10B       all-prefix Detect            78.4%     83.3%      75.3%   0.791
3B        unfreeze 5, all-Detect       77.6%     92.2%      64.4%   0.758
Ensemble  cls-3B + det-10B + DETR      65.7%     61.6%      91.7%   0.737
Ensemble  cls-3B + det-10B             68.6%     79.7%      54.0%   0.736
3B        unfreeze 5 + attention head  67.8%     75.1%      68.6%   0.717
10B       mixed prefix (det + cls)     66.3%     64.8%      78.3%   0.709
3B        WCE loss                     69.6%     72.5%      69.6%   0.700
3B        focal loss                   71.2%     77.4%      63.5%   0.698
3B        mix 448px input              67.4%     92.2%      49.5%   0.645
3B        mix 224px input              65.8%     85.4%      47.8%   0.612
3B        CLAHE preprocessing          73.1%     0%         0%      0.000

A few patterns worth pulling out:

  • "All-prefix Detect" beats "mixed prefix" consistently. The detail: in some experiments we trained the model on a mix of detection prompts ("Detect Fracture") and classification prompts ("Is there a fracture?"), expecting the model to learn both jointly. It learned both worse than focusing on one. One task per fine-tuning run is the safer default.
  • 3B with the right setup gets very close to 10B. F1 of 0.758 (3B unfreeze-5, all-Detect) vs 0.791 (10B same config) — a 3-point gap. At roughly 4× the inference cost for 10B over 3B, that gap might or might not be worth it.
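  • The three-model ensemble (cls-3B + det-10B + DETR) has the highest recall (91.7%), which makes it attractive for screening despite a lower headline F1. The two-model ensemble without DETR trades the other way: 79.7% precision but only 54.0% recall. The intent in both is to combine a high-precision classifier (3B) with a high-recall detector (10B) and get the best of both.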
  • The 28B model isn't on the results table. We trained it but didn't run the full eval suite — at 4-bit quantization with our constraints, the wall-clock training time per epoch was too long to make systematic experiments practical. The qualitative impression was that it wasn't materially better than 10B at the precision we needed; the additional capacity wasn't earning its training cost.
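If you want to sanity-check rows of a table like this, F1 is just the harmonic mean of precision and recall. A small sketch, using two rows from the table above as inputs:

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall (both as fractions in [0, 1])."""
    if precision + recall == 0:
        return 0.0  # avoids division by zero when the model predicts no positives
    return 2 * precision * recall / (precision + recall)


print(round(f1_score(0.833, 0.753), 3))  # 10B all-prefix Detect row (table: 0.791)
print(round(f1_score(0.922, 0.644), 3))  # 3B unfreeze 5, all-Detect row (table: 0.758)
print(f1_score(0.0, 0.0))                # CLAHE row (table: 0.000)
```

The zero-guard is what the CLAHE row exercises: with no positive predictions at all, precision and recall are both zero and F1 is defined as 0.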

Two-stage inference for box refinement

A practical issue: PaliGemma's box predictions drift on subtle pathologies. The model gets the region right but the exact corners are off by 10–30 pixels in either dimension. For coarse triage this is fine; for clinical reporting it isn't.

The workaround is a two-stage pipeline:

  1. Detection stage. PaliGemma proposes one or more boxes per image.
  2. Refinement stage. Each proposed box is cropped (with padding), fed to a small dedicated regression model (we use a lightweight ResNet-50 head trained on cropped box patches), and the corners are refined.

This recovers a few points on the bounding-box-quality metric (mean IoU) without changing the recall characteristics. It also separates the "find it" and "draw the box" concerns into separately optimizable models. For our use case it was the right tradeoff. For a real-time application where two-pass inference doubles latency, it isn't.
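The glue between the two stages is mostly coordinate bookkeeping: compute a padded crop window around each proposed box, then map the refined box back into full-image coordinates. A sketch, with hypothetical helper names and an assumed padding fraction of 25% per side:

```python
import math


def padded_crop_window(box, image_height, image_width, pad_fraction=0.25):
    """Expand a (y_min, x_min, y_max, x_max) pixel box by pad_fraction of its
    size on each side, clamped to the image bounds."""
    y0, x0, y1, x1 = box
    pad_y = (y1 - y0) * pad_fraction
    pad_x = (x1 - x0) * pad_fraction
    return (
        max(0, math.floor(y0 - pad_y)),
        max(0, math.floor(x0 - pad_x)),
        min(image_height, math.ceil(y1 + pad_y)),
        min(image_width, math.ceil(x1 + pad_x)),
    )


def to_image_coords(refined_box_in_crop, window):
    """Map a box predicted inside the crop back to full-image pixel coordinates."""
    wy0, wx0, _, _ = window
    y0, x0, y1, x1 = refined_box_in_crop
    return (y0 + wy0, x0 + wx0, y1 + wy0, x1 + wx0)
```

Cropping from the native-resolution image rather than the 224×224 model input is what lets the refinement stage see detail the detector never had.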


The sweet spot

If we were standing up this system from scratch today with the same data:

  • For best F1, no compute constraint: PaliGemma 10B, "Detect"-only prefix, unfreeze last 5 SigLIP layers, WCE loss for screening (or Focal for confirmatory). Bounding-box F1 around 0.79.
  • For 95% of the F1 at roughly a quarter of the inference cost: PaliGemma 3B with the same config. F1 around 0.76. For most production workloads, this is the rational choice. The cost-per-prediction at 3B is dramatically lower, training fits on smaller hardware, deployment is straightforward.
  • For a screening pipeline that prioritizes recall: the three-model ensemble of 3B classifier + 10B detector + DETR. Recall 91%+, accept the precision hit because human-in-the-loop catches the false positives.
  • Avoid: mixing classification and detection prompts in one fine-tuning run; using CLAHE preprocessing; the 28B model in the 4-bit setup we tested, unless you have specific evidence the extra capacity is earning its training cost.

The unsexy summary: PaliGemma 3B with the right configuration gets you most of the way there, and the right configuration involves unfreezing the last few SigLIP layers, picking one task per fine-tune, choosing the loss function for your operating point (Focal vs WCE), and being honest about whether your preprocessing is helping or destroying the encoder's pre-trained representations.


What I'd do differently

A few things, if starting over:

  • Sweep image size more carefully early. 224 vs 448 affects everything downstream — the SigLIP forward cost, the patch count, the effective resolution of bounding-box predictions. We mostly stuck with 224; 448 might have been a better default for medical detection where small pathologies matter.
  • Quantify the boundary precision ceiling explicitly. The <locXXXX> token grid quantizes coordinates to 1/1024 of the image dimension. For 224×224 inputs, that's a 0.22-pixel grid resolution — fine. For real medical images at native resolution (often 2,000+ pixels per side), the effective bounding-box edge precision is 2+ pixels per token. That's enough to matter for small pathologies. Worth understanding before fine-tuning, not after.
  • Treat ensemble configurations as first-class earlier. The classifier-detector ensemble was a late addition, but it ended up being a better fit for our actual deployment use case than any single model. Designing for it from the start (rather than treating ensembles as a "stack the models we have" afterthought) would have produced cleaner per-stage models.
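The precision-ceiling arithmetic from the second point is worth running for your own image sizes before you start. A trivial sketch (the function name is ours):

```python
def pixels_per_bin(image_side_px, num_bins=1024):
    """Size of one <locNNNN> bin, in pixels, along one image side."""
    return image_side_px / num_bins


for side in (224, 448, 2048):
    print(f"{side}px side -> {pixels_per_bin(side):.3f} px per bin")
```

Anything above one pixel per bin means the token grid, not the model, is the limit on how precisely a box edge can be placed at that resolution.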

Closing

There's almost no good public material on fine-tuning PaliGemma for bounding-box generation, and even less on it for medical-imaging targets. The Google docs cover the format and the basic fine-tuning loop; everything past that — the SigLIP unfreeze question, the loss choices, the preprocessing-can-break-the-encoder lesson, the model-size sweet spot — is uncharted in any public writeup I've found.

If you're working on the same problem, here are the takeaways in one paragraph: use PaliGemma 3B as your starting point, unfreeze the last 5 SigLIP layers, train with one prompt template per fine-tune, pick WCE or Focal based on whether you care more about recall or precision, and skip CLAHE. Get the baseline F1 measured before reaching for 10B. If 10B genuinely earns its cost on your task, switch; otherwise stay on 3B and put the saved compute into ensemble diversity or two-stage refinement. That's the playbook.


Part of an ongoing series on production medical imaging. The companion year-one reflection is here; the shoulder-fracture ensemble paper repackage is here; the Gemini-vs-CNN clinical-QC lab note is here. If you're working through the same fine-tuning question, reach out.