[discrete diffusion] Add DiffusionGemma pipeline and schedulers#13986
[discrete diffusion] Add DiffusionGemma pipeline and schedulers#13986kashif wants to merge 42 commits into
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
zucchini-nlp
left a comment
There was a problem hiding this comment.
Looking great! A couple questions from quick skimming
yiyixuxu
left a comment
There was a problem hiding this comment.
thanks for the PR! i left a few comments
I reviewed this through the lens of diffuser convention/style. If some of these choices are intentional to keep things familiar for Transformers users, let me know, and we can figure out the right balance together
| def __call__( | ||
| self, | ||
| prompt: str | list[str] | None = None, | ||
| messages: list[dict[str, str]] | None = None, |
There was a problem hiding this comment.
I think between prompt and messages, we only need accept prompt since it's a really cheap into messages
it's just this, no?
messages = [{"role": "user", "content": prompt}]There was a problem hiding this comment.
Makes sense. The one wrinkle is image prompts, which we pass through messages today, so I'll fold the prompt/messages simplification into the image input rework so single-image and text both stay clean. Coming in a follow-up.
There was a problem hiding this comment.
Made prompt the primary input and dropped the tokenized intermediates. Kept messages for raw multi-turn/multimodal conversations (per the thread below with zucchini), and added a raw image arg for the simple prompt+image case, so it is all raw inputs now.
Adds optional Gibbs corrector sweeps after each predictor step for uniform diffusion, recovering the LOO denoiser in closed form so it works on the released checkpoint with no retraining. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The denoiser is a Transformers model, so adapters (LoRA, DoRA, ...) load through its native PEFT integration rather than the diffusers LoRA loader. Also dispatch the predictor-corrector by scheduler capability instead of class. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
dg845
left a comment
There was a problem hiding this comment.
Thanks for the PR! Left some design comments :).
|
thanks @dg845 fixing |
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
…ma.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
01d4990 to
1f46257
Compare
zucchini-nlp
left a comment
There was a problem hiding this comment.
Thanks for working on it, i think the overall latency and quality matches model released in transformers!
|
When I tried out the example script: import torch
from transformers import AutoProcessor, DiffusionGemmaForBlockDiffusion
from diffusers import BlockRefinementScheduler, DiffusionGemmaPipeline
model_id = "google/diffusiongemma-26B-A4B-it"
dtype = torch.bfloat16
model = DiffusionGemmaForBlockDiffusion.from_pretrained(model_id, dtype=dtype, device_map="auto")
# model = DiffusionGemmaForBlockDiffusion.from_pretrained(model_id, dtype=dtype)
processor = AutoProcessor.from_pretrained(model_id)
scheduler = BlockRefinementScheduler()
pipe = DiffusionGemmaPipeline(model=model, scheduler=scheduler, processor=processor)
pipe.model.model.decoder = torch.compile(pipe.model.model.decoder, mode="reduce-overhead", fullgraph=True)
# pipe.to("cuda")
output = pipe(
prompt="Why is the sky blue?",
gen_length=256,
num_inference_steps=48,
cache_implementation="static",
generator=torch.Generator("cuda").manual_seed(42),
)
print(output.texts[0])I get the following torch._dynamo.exc.Unsupported: Skip calling `torch.compiler.disable()`d function
Explanation: Skip calling function `<function AlignDevicesHook.pre_forward at 0x73a280162950>` since it was wrapped with `torch.compiler.disable` (reason: None)
Hint: Remove the `torch.compiler.disable` callThe script works if I either don't EDIT: the error can also be mitigated by using |
|
|
||
| from diffusers import BlockRefinementScheduler, DiffusionGemmaPipeline | ||
|
|
||
| model_id = "google/diffusiongemma-26B-A4B-it" |
There was a problem hiding this comment.
@yiyixuxu @dg845 like the rest of our diffusers checkpoints repositories, where we have pipeline components coming from a different repo, could we make this pipeline a diffusers-style checkpoint with something like:
model = DiffusionGemmaForBlockDiffusion.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)
scheduler = BlockRefinementScheduler()
pipe = DiffusionGemmaPipeline(model=model, scheduler=scheduler, processor=processor)
pipe.save_pretrained(...)| # Encode the tokens not yet in the cache (the whole prompt on the first block, the last committed canvas | ||
| # afterwards), so the decoder reuses the encoder KV cache instead of re-encoding the full sequence. | ||
| cached_len = past_key_values.get_seq_length() | ||
| torch.compiler.cudagraph_mark_step_begin() |
There was a problem hiding this comment.
Why do we need this? We usually don't use CUDAgraph markers like this explicitly in our pipelines.
|
@claude could you review this PR and also comment on the usage of the torch ops usage pattern from the lens of efficiency? |
|
Claude encountered an error —— View job I'll analyze this and get back to you. |
…ma.py Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
| canvas = scheduler_output.prev_sample | ||
| # Self-condition on the temperature-shaped logits when the scheduler shapes them (the reference | ||
| # sampler), else the raw logits. | ||
| self_conditioning_logits = scheduler_output.get("pred_logits", logits) |
There was a problem hiding this comment.
If I understand correctly, the self-conditioning logits are the temperature-shaped logits for EntropyBoundScheduler (since it defines a pred_logits output field), but the raw logits for BlockRefinementScheduler and DiscreteDDIMScheduler (which don't define pred_logits). Is this intentional? If pred_logits is a property of the sampling algorithm (which I'm not sure about), I think it would be more clear for BlockRefinementScheduler and DiscreteDDIMScheduler to also output a pred_logits field (even if it's just the raw logits).
| # Anneal the temperature from `t_max` to `t_min` over the schedule and scale the logits by it once, so the | ||
| # acceptance entropy is measured on the same distribution the candidates are drawn from. | ||
| fraction = (self.num_inference_steps - int(timestep)) / self.num_inference_steps | ||
| temperature = self.config.t_min + (self.config.t_max - self.config.t_min) * fraction |
There was a problem hiding this comment.
Should the last temperature on the annealing schedule be self.config.t_min? I believe right now the last fraction is 1 / self.num_inference_steps, so the temperature at the last timestep self.num_inference_steps - 1 is strictly greater than t_min.
Adds a DiffusionGemma block-diffusion pipeline, alongside the schedulers already on this branch (discrete DDIM, entropy bound, and a uniform mode for block refinement).
DiffusionGemma is an encoder-decoder block-diffusion model: the encoder reads the prompt into a KV cache and the decoder denoises a fixed-size canvas by cross-attending to it. The pipeline runs the outer canvas loop and the inner denoising loop, sampling candidates each step, committing the most confident ones via
BlockRefinementSchedulerin uniform corruption mode, and renoising the rest. Structure mirrors the LLaDA2 and dflash (#13699) pipelines.The model itself lives in transformers as
DiffusionGemmaForBlockDiffusion(released in 5.12.0).Tested:
Quality on the full
google/diffusiongemma-26B-A4B-itcheckpoint still needs a GPU run.