@amorehead amorehead commented Sep 18, 2025

  • Adds a Jupyter notebook with an example multimodal flow matching model built from the flow_matching package's components.
  • Introduces two shared utilities, flow_matching/utils/multimodal.py and flow_matching/solver/multimodal_solver.py, for constructing instances of the new multimodal Flow class and sampling from them.
  • This example's architecture differs from that of the other examples: each modality has its own dedicated input and output MLPs, while a common Transformer trunk is shared across all modalities. Further tuning of the training setup would likely improve this example's results.
  • As secondary validation, I've also trained a model on a real-world dataset using the multimodal utilities included in this PR.

@meta-cla meta-cla bot added the cla signed label Sep 18, 2025
@itaigat
Contributor

itaigat commented Sep 19, 2025

Thanks @amorehead! Can you please give a more detailed explanation of the multimodal model?

@amorehead
Author

amorehead commented Sep 19, 2025

Hi, @itaigat. The multimodal model uses MLPs similar to those in the discrete and continuous (checkerboard) flow matching examples. Here, however, they produce the input and output embeddings for each modality, and all modalities share a common nn.TransformerEncoder module that is applied to their (concatenated) input embeddings. Because these Transformer weights sit in the middle of the new TransformerModel and are shared, common features can be learned across all modalities (at least that's the intuition).
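
For readers who want to picture this layout, here is a minimal PyTorch sketch. It is not the TransformerModel from this PR: the class name SharedTrunkModel, the layer sizes, the time conditioning, and the choice to feed the discrete modality as a one-hot vector are all assumptions made for illustration. It also assumes that each modality contributes one token and that the tokens are concatenated along the sequence axis.

```python
import torch
import torch.nn as nn


class SharedTrunkModel(nn.Module):
    """Hypothetical sketch: per-modality MLPs around one shared Transformer trunk."""

    def __init__(self, in_dims, out_dims, hidden=64, heads=4, layers=2):
        super().__init__()
        # One input MLP per modality; the extra input feature is the time t.
        self.inputs = nn.ModuleList([
            nn.Sequential(nn.Linear(d + 1, hidden), nn.SiLU(), nn.Linear(hidden, hidden))
            for d in in_dims
        ])
        # The trunk is the only part whose weights are shared by all modalities.
        layer = nn.TransformerEncoderLayer(
            d_model=hidden, nhead=heads, dim_feedforward=4 * hidden, batch_first=True
        )
        self.trunk = nn.TransformerEncoder(layer, num_layers=layers)
        # One output MLP per modality (e.g. a velocity, or logits over a vocabulary).
        self.outputs = nn.ModuleList([
            nn.Sequential(nn.Linear(hidden, hidden), nn.SiLU(), nn.Linear(hidden, d))
            for d in out_dims
        ])

    def forward(self, xs, t):
        # xs: list of (batch, in_dim_i) tensors; t: (batch,) times in [0, 1].
        t = t[:, None]
        tokens = [mlp(torch.cat([x, t], dim=-1)) for mlp, x in zip(self.inputs, xs)]
        h = self.trunk(torch.stack(tokens, dim=1))  # (batch, n_modalities, hidden)
        return [mlp(h[:, i]) for i, mlp in enumerate(self.outputs)]


# Usage: a 2D continuous modality plus one discrete token with 8 possible values.
model = SharedTrunkModel(in_dims=[2, 8], out_dims=[2, 8])
x_cont = torch.randn(4, 2)
x_disc = nn.functional.one_hot(torch.randint(0, 8, (4,)), 8).float()
velocity, logits = model([x_cont, x_disc], torch.rand(4))
```

Because self-attention in the trunk mixes the modality tokens, each output head can condition on every modality's current state, which is where the shared features would come from.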

@amorehead
Author

amorehead commented Sep 24, 2025

@itaigat, I've added a new multimodal_solver.py that natively supports solving the sampling trajectories of a variable number of discrete and continuous modalities with a single (multimodal) model. This has improved the performance of the multimodal model presented in the new 2d_multimodal_flow_matching.ipynb example notebook. Its sampling visualizations could likely be improved by using a larger sampling batch size, but I'm currently limited to running on a MacBook, so apologies in advance. I've also added two new unit tests, one for MultimodalSolver and one for the multimodal Flow class. Let me know if you have any questions.
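
To illustrate what solving several modalities' trajectories with a single model can look like, here is a rough Euler-style sketch. It is not the MultimodalSolver in this PR, and its interface is hypothetical: the function name sample_jointly and the model signature model(x_cont, x_disc, t) -> (velocity, logits) are assumptions. It handles one continuous and one discrete modality, and for the discrete one it assumes a mixture path with a linear schedule, under which a token jumps to a sampled endpoint with probability h / (1 - t) per step.

```python
import torch


@torch.no_grad()
def sample_jointly(model, x_cont, x_disc, n_steps=50):
    """Hypothetical joint Euler sampler; not the PR's MultimodalSolver API.

    Assumes model(x_cont, x_disc, t) returns (velocity, logits): a velocity
    for the continuous state and logits over each discrete token's final value.
    """
    h = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((x_cont.shape[0],), i * h)
        # One forward pass per step drives every modality's update.
        velocity, logits = model(x_cont, x_disc, t)
        # Discrete modality: sample an endpoint, then jump to it with
        # probability h / (1 - t), which equals 1 / (n_steps - i) here.
        x1 = torch.distributions.Categorical(logits=logits).sample()
        jump = torch.rand(x_disc.shape) < 1.0 / (n_steps - i)
        x_disc = torch.where(jump, x1, x_disc)
        # Continuous modality: explicit Euler step along the velocity field.
        x_cont = x_cont + h * velocity
    return x_cont, x_disc
```

Both updates use the model outputs computed from the same pre-step state, so the modalities stay synchronized in time. On the last step the jump probability is exactly 1, so every discrete token ends at a sampled endpoint.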

2 participants