JAX実装: F2D2(Flowモデルにおける高速尤度評価とサンプリングのためのJoint Distillation)

AI/ML

概要

F2D2-Jaxは、論文「F2D2: Joint Distillation for Fast Likelihood Evaluation and Sampling in Flow-based Models」の手法をJAXで実装したリポジトリです。Flowベースの教師モデルから、尤度評価とサンプリングを高速化するために共同蒸留(joint distillation)を行うアプローチを扱っています。元の手法や関連する「flow maps / consistency model」に基づいた実験スクリプトやノートブック、簡易デモ(checkerboardなど)が含まれており、研究目的での再現や改良、JAXの高速計算機能を利用したプロトタイプ開発に適しています。

GitHub

リポジトリの統計情報

  • スター数: 2
  • フォーク数: 1
  • ウォッチャー数: 2
  • コミット数: 2
  • ファイル数: 7
  • メインの言語: Jupyter Notebook

主な特徴

  • F2D2手法のJAXによるプロトタイプ実装(論文実装の簡易再現)
  • ノートブックベースで実験を再現しやすい構成(可視化、デモ含む)
  • 公式 flow-maps 実装をベースにした派生実装で、consistency modelやflow map関連の手法と親和性が高い
  • 軽量なコードベースで研究や教育用に利用しやすい

技術的なポイント

本実装はJAXの利点(関数変換、JITコンパイル、自動微分、乱数管理など)を活かして、Flowベース教師モデルからの蒸留プロセスを効率化することを狙っています。F2D2の本質は「尤度評価器」と「サンプラー」を同時に学習(joint distillation)する点にあり、教師モデルの高精度な推定を学生モデルに転移することで推論時の計算コストを下げます。JAX実装では、データ並列やvmapを使ったバッチ処理、JITによる高速化、数値安定化のための工夫(ログ確率の扱い、逆伝播でのクリッピングなど)が想定されます。ノートブックは主に小規模データセットや合成データ(checkerboard)での動作確認、学習曲線の可視化、サンプリング品質の比較を行えるよう設計されており、オリジナル論文で提案された評価指標や蒸留損失の構成要素を検証するのに役立ちます。また、元実装(flow-maps)との比較・改変がしやすいコード構成になっており、研究者がモデル構造や蒸留戦略を試行錯誤する際の出発点として有用です。

プロジェクトの構成

主要なファイルとディレクトリ:

  • .gitignore: file
  • LICENSE: file
  • README.md: file
  • checkerboard: dir
  • notebooks: dir

…他 2 ファイル

まとめ

JAXで手早くF2D2を試せる研究向けの実装です。

リポジトリ情報:

READMEの抜粋:

F2D2-Jax

JAX implementation of F2D2:Joint Distillation for Fast Likelihood Evaluation and Sampling in Flow-based Models https://arxiv.org/abs/2512.02636

Modified from Official repository for “How to build a consistency model: Learning flow maps via self-distillation” (NeurIPS 2025). https://arxiv.org/abs/2505.18825 by Nicholas M. Boffi (CMU), Michael Albergo (Harvard), and Eric Vanden-Eijnden (Courant Institute of Mathematical Sciences, Capit…