diff --git a/README.md b/README.md index 0e7b5b7..903cf32 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ pip install -e . #### I want to train a Flow Matching model, where can I find the training code? -We provide [training examples](examples). Under this folder, you can find synthetic data for [continuous](examples/2d_flow_matching.ipynb), [discrete](examples/2d_discrete_flow_matching.ipynb), and [Riemannian](examples/2d_riemannian_flow_matching_flat_torus.ipynb) Flow Matching. We also provide full training [examples](examples/image) (continuous and discrete) on CIFAR10 and face-blurred ImageNet, and a scalable discrete Flow Matching example for [text modeling](examples/text). +We provide [training examples](examples). Under this folder, you can find synthetic data for [continuous](examples/2d_flow_matching.ipynb), [discrete](examples/2d_discrete_flow_matching.ipynb), [multimodal](examples/2d_multimodal_flow_matching.ipynb), and [Riemannian](examples/2d_riemannian_flow_matching_flat_torus.ipynb) Flow Matching. We also provide full training [examples](examples/image) (continuous and discrete) on CIFAR10 and face-blurred ImageNet, and a scalable discrete Flow Matching example for [text modeling](examples/text). #### Do you release pre-trained models? 
diff --git a/docs/Makefile b/docs/Makefile index 05c2bb7..ef4da3c 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -13,6 +13,7 @@ ROOT_DIR:=$(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) links: mkdir -p source/notebooks && ln -sfn $(ROOT_DIR)/../examples/standalone_flow_matching.ipynb source/notebooks/standalone_flow_matching.ipynb mkdir -p source/notebooks && ln -sfn $(ROOT_DIR)/../examples/2d_discrete_flow_matching.ipynb source/notebooks/2d_discrete_flow_matching.ipynb + mkdir -p source/notebooks && ln -sfn $(ROOT_DIR)/../examples/2d_multimodal_flow_matching.ipynb source/notebooks/2d_multimodal_flow_matching.ipynb mkdir -p source/notebooks && ln -sfn $(ROOT_DIR)/../examples/2d_riemannian_flow_matching_flat_torus.ipynb source/notebooks/2d_riemannian_flow_matching_flat_torus.ipynb mkdir -p source/notebooks && ln -sfn $(ROOT_DIR)/../examples/2d_riemannian_flow_matching_sphere.ipynb source/notebooks/2d_riemannian_flow_matching_sphere.ipynb ln -sfn $(ROOT_DIR)/../assets/teaser.png source/_images/teaser.png diff --git a/docs/source/_images/multimodal.png b/docs/source/_images/multimodal.png new file mode 100644 index 0000000..a475d87 Binary files /dev/null and b/docs/source/_images/multimodal.png differ diff --git a/docs/source/dummy.rst b/docs/source/dummy.rst index 820ca98..cef70b2 100644 --- a/docs/source/dummy.rst +++ b/docs/source/dummy.rst @@ -5,5 +5,6 @@ notebooks/standalone_flow_matching notebooks/2d_discrete_flow_matching + notebooks/2d_multimodal_flow_matching notebooks/2d_riemannian_flow_matching_flat_torus notebooks/2d_riemannian_flow_matching_sphere diff --git a/docs/source/flow_matching.solver.rst b/docs/source/flow_matching.solver.rst index 99b00e8..dd8ebd1 100644 --- a/docs/source/flow_matching.solver.rst +++ b/docs/source/flow_matching.solver.rst @@ -14,5 +14,6 @@ Solvers Solver ODESolver MixtureDiscreteEulerSolver + MultimodalSolver RiemannianODESolver diff --git a/docs/source/flow_matching.utils.multimodal.rst 
b/docs/source/flow_matching.utils.multimodal.rst new file mode 100644 index 0000000..024e132 --- /dev/null +++ b/docs/source/flow_matching.utils.multimodal.rst @@ -0,0 +1,18 @@ +``flow_matching.utils.multimodal`` +================================== + +.. currentmodule:: flow_matching.utils.multimodal + + +Flow +-------------------------------- + +Generic multimodal flow class + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + Flow + diff --git a/docs/source/modules.rst b/docs/source/modules.rst index 093361a..c3bc258 100644 --- a/docs/source/modules.rst +++ b/docs/source/modules.rst @@ -10,3 +10,4 @@ API Reference flow_matching.solver flow_matching.utils.model_wrapper flow_matching.utils.manifolds + flow_matching.utils.multimodal diff --git a/docs/source/notebooks.rst b/docs/source/notebooks.rst index 1967e99..5800b6a 100644 --- a/docs/source/notebooks.rst +++ b/docs/source/notebooks.rst @@ -29,4 +29,10 @@ Notebooks :image: _static/riemannian_torus.png :link: notebooks/2d_riemannian_flow_matching_flat_torus.html +.. customcarditem:: + :header: Multimodal Flow Matching + :card_description: Train and sample from a 2D Multimodal Flow Matching model. + :image: _static/multimodal.png + :link: notebooks/2d_multimodal_flow_matching.html + ..
customcardend:: diff --git a/examples/2d_multimodal_flow_matching.ipynb b/examples/2d_multimodal_flow_matching.ipynb new file mode 100644 index 0000000..2a31ec1 --- /dev/null +++ b/examples/2d_multimodal_flow_matching.ipynb @@ -0,0 +1,667 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "381ea8d5", + "metadata": {}, + "source": [ + "# A simple 2D Multimodal Flow Matching model" + ] + }, + { + "cell_type": "markdown", + "id": "0c0a75af", + "metadata": {}, + "source": [ + "This notebook trains and evaluates a multimodal FM model that jointly handles\n", + "a discrete modality (categorical data) and a continuous modality (real‑valued 2‑D data).\n", + "\n", + "Dataset: 2D discrete/continuous checkerboard\n", + "Model (probability denoiser/velocity): MLPs for each modality and a shared Transformer trunk" + ] + }, + { + "cell_type": "markdown", + "id": "b5c941fc", + "metadata": {}, + "source": [ + "## Imports and init device" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e7758331", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "from typing import Any, Dict, List, Sequence\n", + "\n", + "# visualization\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "from flow_matching.path import AffineProbPath, MixtureDiscreteProbPath\n", + "from flow_matching.path.scheduler import (\n", + " CondOTScheduler, # continuous scheduler (training)\n", + " PolynomialConvexScheduler, # discrete scheduler (training)\n", + ")\n", + "\n", + "# flow_matching\n", + "from flow_matching.utils.multimodal import Flow\n", + "from torch import nn, Tensor\n", + "\n", + "# To avoid meshgrid warning\n", + "import warnings\n", + "\n", + "warnings.filterwarnings(\"ignore\", category=UserWarning, module=\"torch\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "10957ca3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using MPS\n" + ] + } + ], + "source": [ 
+ "if torch.cuda.is_available():\n", + " device = \"cuda:0\"\n", + " print(\"Using GPU\")\n", + "elif torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + " print(\"Using MPS\")\n", + "else:\n", + " device = \"cpu\"\n", + " print(\"Using CPU\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0491f488", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.manual_seed(42)" + ] + }, + { + "cell_type": "markdown", + "id": "b2ff4e5f", + "metadata": {}, + "source": [ + "## Shared model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1321dec5", + "metadata": {}, + "outputs": [], + "source": [ + "class SharedTransformer(nn.Module):\n", + " \"\"\"\n", + " Shared Transformer trunk used by both modalities.\n", + "\n", + " Args:\n", + " hidden_dim (int): The hidden dimension of the model.\n", + " nhead (int): The number of attention heads.\n", + " num_layers (int): The number of TransformerEncoder layers.\n", + " \"\"\"\n", + "\n", + " def __init__(self, hidden_dim: int = 128, nhead: int = 4, num_layers: int = 2):\n", + " super().__init__()\n", + " encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead)\n", + " self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Forward pass through the shared Transformer.\n", + "\n", + " Args:\n", + " x (Tensor): Input tensor of shape (sequence_length, batch_size, hidden_dim).\n", + "\n", + " Returns:\n", + " Tensor: Output tensor of the same shape as input.\n", + " \"\"\"\n", + " return self.transformer(x)" + ] + }, + { + "cell_type": "markdown", + "id": "af22ef56", + "metadata": {}, + "source": [ + "## Datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "b2466329", + "metadata": {}, + "outputs": [], + 
"source": [ + "def inf_train_gen_discrete(\n", + " n_grid_points: int = 128,\n", + " batch_size: int = 200,\n", + " device: str = \"cpu\",\n", + ") -> Tensor:\n", + " \"\"\"\n", + " Generate a batch of discrete (categorical) samples.\n", + " Returns a tensor of shape (batch, 2) with integer token IDs.\n", + "\n", + " Args:\n", + " n_grid_points (int): Number of grid points along one axis (should be divisible by 4).\n", + " batch_size (int): Number of samples to generate.\n", + " device (str): Device to place the tensor on.\n", + "\n", + " Returns:\n", + " Tensor: A tensor of shape (batch_size, 2) with integer token IDs.\n", + " \"\"\"\n", + " assert n_grid_points % 4 == 0, \"grid size must be divisible by 4\"\n", + " n_grid_points //= 4\n", + "\n", + " x1 = torch.randint(low=0, high=n_grid_points * 4, size=(batch_size,), device=device)\n", + " samples_x2 = torch.randint(\n", + " low=0, high=n_grid_points, size=(batch_size,), device=device\n", + " )\n", + "\n", + " x2 = (\n", + " samples_x2\n", + " + 2 * n_grid_points\n", + " - torch.randint(low=0, high=2, size=(batch_size,), device=device)\n", + " * 2\n", + " * n_grid_points\n", + " + (torch.floor(x1 / n_grid_points) % 2) * n_grid_points\n", + " )\n", + " return torch.stack([x1, x2], dim=1).long()\n", + "\n", + "\n", + "def inf_train_gen_continuous(batch_size: int = 200, device: str = \"cpu\") -> Tensor:\n", + " \"\"\"\n", + " Generate a batch of 2-D continuous points from a checkerboard-like distribution.\n", + " Returns a tensor of shape (batch, 2).\n", + "\n", + " Args:\n", + " batch_size (int): Number of samples to generate.\n", + " device (str): Device to place the tensor on.\n", + "\n", + " Returns:\n", + " Tensor: A tensor of shape (batch_size, 2) with continuous values.\n", + " \"\"\"\n", + " x1 = torch.rand(batch_size, device=device) * 4 - 2\n", + " x2_ = (\n", + " torch.rand(batch_size, device=device)\n", + " - torch.randint(high=2, size=(batch_size,), device=device) * 2\n", + " )\n", + " x2 = x2_ + 
(torch.floor(x1) % 2)\n", + " data = torch.stack([x1, x2], dim=1) / 0.45\n", + " return data.float()" + ] + }, + { + "cell_type": "markdown", + "id": "e1faf8fd", + "metadata": {}, + "source": [ + "## Unified multimodal model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a3517fbc", + "metadata": {}, + "outputs": [], + "source": [ + "class Swish(nn.Module):\n", + " \"\"\"Swish activation (x * sigmoid(x)).\"\"\"\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"Forward pass through the Swish activation.\"\"\"\n", + " return torch.sigmoid(x) * x\n", + "\n", + "\n", + "class TransformerModel(nn.Module):\n", + " \"\"\"\n", + " A unified Transformer-based model for handling multiple modalities.\n", + "\n", + " This model processes a sequence of modalities, each with its own input\n", + " and output heads, while sharing a central Transformer trunk. It is designed\n", + " to be flexible for both discrete (categorical) and continuous data types.\n", + "\n", + " Args:\n", + " shared_transformer (SharedTransformer): The shared TransformerEncoder module.\n", + " modality_configs (List[Dict[str, Any]]): A list of dictionaries, each configuring a modality.\n", + " Required keys per config:\n", + " - 'type': 'discrete' or 'continuous'.\n", + " - 'length': The sequence length for this modality's tokens.\n", + " If 'type' is 'discrete':\n", + " - 'vocab_size': The size of the vocabulary.\n", + " If 'type' is 'continuous':\n", + " - 'input_dim': The feature dimension of the continuous data.\n", + " time_dim (int): The dimension of the time embedding.\n", + " hidden_dim (int): The hidden dimension of the model and transformer.\n", + "\n", + " Raises:\n", + " ValueError: If an unknown modality type is provided.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " shared_transformer: SharedTransformer,\n", + " modality_configs: List[Dict[str, Any]],\n", + " time_dim: int = 1,\n", + " hidden_dim: int = 128,\n", + " ):\n", + " 
super().__init__()\n", + " self.shared = shared_transformer\n", + " self.modality_configs = modality_configs\n", + " self.seq_lengths = [config[\"length\"] for config in modality_configs]\n", + "\n", + " self.input_embedders = nn.ModuleList()\n", + " self.time_embedders = nn.ModuleList()\n", + " self.input_projectors = nn.ModuleList()\n", + " self.output_heads = nn.ModuleList()\n", + " self.activations = nn.ModuleList()\n", + "\n", + " for config in self.modality_configs:\n", + " self.time_embedders.append(nn.Linear(1, time_dim))\n", + " self.input_projectors.append(nn.Linear(hidden_dim + time_dim, hidden_dim))\n", + " self.activations.append(Swish())\n", + "\n", + " if config[\"type\"] == \"discrete\":\n", + " self.input_embedders.append(\n", + " nn.Embedding(config[\"vocab_size\"], hidden_dim)\n", + " )\n", + " self.output_heads.append(nn.Linear(hidden_dim, config[\"vocab_size\"]))\n", + " elif config[\"type\"] == \"continuous\":\n", + " self.input_embedders.append(nn.Linear(config[\"input_dim\"], hidden_dim))\n", + " self.output_heads.append(nn.Linear(hidden_dim, config[\"input_dim\"]))\n", + " else:\n", + " raise ValueError(f\"Unknown modality type: {config['type']}\")\n", + "\n", + " def forward(\n", + " self, x_modalities: Sequence[Tensor], t_modalities: Sequence[Tensor]\n", + " ) -> Sequence[Tensor]:\n", + " \"\"\"\n", + " Forward pass for multiple modalities.\n", + "\n", + " Args:\n", + " x_modalities (Sequence[Tensor]): A sequence of input tensors, one for each modality.\n", + " Shape for discrete: (batch, length)\n", + " Shape for continuous: (batch, input_dim)\n", + " t_modalities (Sequence[Tensor]): A sequence of time tensors, one for each modality.\n", + " Shape for all: (batch,)\n", + "\n", + " Returns:\n", + " Sequence[Tensor]: A sequence of output tensors, one for each modality.\n", + " \"\"\"\n", + " embeddings = []\n", + "\n", + " # 1.
Process each modality through its specific input head\n", + " for i, (x, t, config) in enumerate(\n", + " zip(x_modalities, t_modalities, self.modality_configs)\n", + " ):\n", + " # Embed time and expand to match sequence length\n", + " t_emb = self.time_embedders[i](t.unsqueeze(-1))\n", + " t_emb = t_emb.unsqueeze(1).expand(-1, config[\"length\"], -1)\n", + "\n", + " # Embed input based on modality type\n", + " if config[\"type\"] == \"discrete\":\n", + " x_emb = self.input_embedders[i](x) # (B, length, hidden_dim)\n", + " else: # continuous\n", + " x_emb = self.input_embedders[i](x) # (B, hidden_dim)\n", + " x_emb = x_emb.unsqueeze(1) # (B, 1, hidden_dim)\n", + "\n", + " # Combine, project, and activate\n", + " combined = torch.cat([x_emb, t_emb], dim=-1)\n", + " h = self.input_projectors[i](combined)\n", + " h = self.activations[i](h)\n", + "\n", + " # Prepare for transformer (seq_len, batch, hidden_dim)\n", + " embeddings.append(h.permute(1, 0, 2))\n", + "\n", + " # 2. Concatenate all modality embeddings and pass through shared transformer\n", + " full_sequence = torch.cat(embeddings, dim=0)\n", + " transformer_out = self.shared(full_sequence)\n", + "\n", + " # 3. 
Split the output and process through specific output heads\n", + " output_chunks = torch.split(transformer_out, self.seq_lengths, dim=0)\n", + " results = []\n", + " for i, chunk in enumerate(output_chunks):\n", + " # (length, B, hidden_dim) -> (B, length, hidden_dim)\n", + " chunk = chunk.permute(1, 0, 2)\n", + " output = self.output_heads[i](chunk)\n", + "\n", + " # Squeeze sequence dimension if it's 1 (for continuous case)\n", + " if output.size(1) == 1:\n", + " output = output.squeeze(1)\n", + " results.append(output)\n", + "\n", + " return results" + ] + }, + { + "cell_type": "markdown", + "id": "d5378557", + "metadata": {}, + "source": [ + "## Instantiate modalities and model" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9b0e8daa", + "metadata": {}, + "outputs": [], + "source": [ + "# ---- General Hyperparameters -----------------------------------------\n", + "length = 2 # 2 tokens per sample\n", + "vocab_size = 128\n", + "added_token = 0 # uniform source distribution → no extra token\n", + "vocab_size += added_token\n", + "hidden_dim = 128\n", + "\n", + "# ---- Shared transformer trunk ----------------------------------------\n", + "shared_transformer = SharedTransformer(hidden_dim=hidden_dim, nhead=4, num_layers=2).to(\n", + " device\n", + ")\n", + "\n", + "# ---- Model and Path Configuration ------------------------------------\n", + "modality_configs = [\n", + " {\n", + " \"type\": \"discrete\",\n", + " \"vocab_size\": vocab_size,\n", + " \"length\": length,\n", + " },\n", + " {\n", + " \"type\": \"continuous\",\n", + " \"input_dim\": length,\n", + " \"length\": 1, # This modality is treated as a single token in the sequence\n", + " },\n", + "]\n", + "\n", + "# A unified model that handles all modalities\n", + "model = TransformerModel(\n", + " shared_transformer=shared_transformer,\n", + " modality_configs=modality_configs,\n", + " time_dim=1,\n", + " hidden_dim=hidden_dim,\n", + ").to(device)\n", + "\n", + "# Path definitions 
remain distinct per modality\n", + "discrete_path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=2.0))\n", + "continuous_path = AffineProbPath(scheduler=CondOTScheduler())\n", + "\n", + "# ---- Assemble modalities dict for Flow -------------------------------\n", + "modalities = {\n", + " \"discrete\": {\n", + " \"path\": discrete_path,\n", + " # loss omitted → Flow will use MixturePathGeneralizedKL automatically\n", + " },\n", + " \"continuous\": {\n", + " \"path\": continuous_path,\n", + " # loss omitted → Flow will use MSE loss automatically\n", + " },\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "b82a25cc", + "metadata": {}, + "source": [ + "## Instantiate the multimodal Flow model" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9f2ccedd", + "metadata": {}, + "outputs": [], + "source": [ + "flow = Flow(model=model, modalities=modalities)\n", + "\n", + "# Optimizer (optimises both modality models)\n", + "optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "markdown", + "id": "2636f3a4", + "metadata": {}, + "source": [ + "## Training loop" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "646de9a8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "| iter 3000 | 14.35 ms/step | loss 9.040 \n", + "| iter 6000 | 16.78 ms/step | loss 9.292 \n", + "| iter 9000 | 18.14 ms/step | loss 9.037 \n", + "| iter 12000 | 18.66 ms/step | loss 9.878 \n", + "| iter 15000 | 18.56 ms/step | loss 9.466 \n", + "| iter 18000 | 18.55 ms/step | loss 9.251 \n", + "| iter 21000 | 18.27 ms/step | loss 9.220 \n", + "| iter 24000 | 18.40 ms/step | loss 9.489 \n", + "| iter 27000 | 18.48 ms/step | loss 9.835 \n", + "| iter 30000 | 18.33 ms/step | loss 9.114 \n" + ] + } + ], + "source": [ + "lr = 1e-3\n", + "batch_size = 2048\n", + "iterations = 30001\n", + "print_every = 3000\n", + "epsilon = 1e-3\n", + "\n", + "source_distribution = 
\"uniform\" # for the discrete modality\n", + "\n", + "start_time = time.time()\n", + "for i in range(iterations):\n", + " optimizer.zero_grad()\n", + "\n", + " # ---- Discrete data -------------------------------------------------\n", + " x1_disc = inf_train_gen_discrete(\n", + " n_grid_points=vocab_size - added_token,\n", + " batch_size=batch_size,\n", + " device=device,\n", + " )\n", + " if source_distribution == \"uniform\":\n", + " x0_disc = torch.randint_like(x1_disc, high=vocab_size)\n", + " else: # mask case (not used here)\n", + " raise NotImplementedError\n", + "\n", + " # ---- Continuous data -----------------------------------------------\n", + " x1_cont = inf_train_gen_continuous(batch_size=batch_size, device=device)\n", + " x0_cont = torch.randn_like(x1_cont) # isotropic Gaussian prior\n", + "\n", + " # ---- Sample a common time tensor for both modalities ---------------\n", + " t = torch.rand(batch_size, device=device) * (1 - epsilon)\n", + "\n", + " # ---- Sample from each path to obtain x_t ---------------------------\n", + " disc_path_sample = discrete_path.sample(t=t, x_0=x0_disc, x_1=x1_disc)\n", + " cont_path_sample = continuous_path.sample(t=t, x_0=x0_cont, x_1=x1_cont)\n", + "\n", + " # ---- Build the inputs expected by Flow.training_loss -----------\n", + " x_1 = [x1_disc, x1_cont]\n", + " x_t = [disc_path_sample.x_t, cont_path_sample.x_t]\n", + " dx_t = [None, cont_path_sample.dx_t] # NOTE: dx_t is None for discrete\n", + " ts = [t] * 2 # NOTE: For now, both modalities share the same time\n", + "\n", + " # ---- Compute total loss and back‑propagate -------------------------\n", + " loss, _ = flow.training_loss(x_1=x_1, x_t=x_t, dx_t=dx_t, t=ts)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # ---- Logging -------------------------------------------------------\n", + " if (i + 1) % print_every == 0:\n", + " elapsed = time.time() - start_time\n", + " print(\n", + " f\"| iter {i+1:6d} | {elapsed*1000/print_every:5.2f} ms/step 
| loss {loss.item():8.3f} \"\n", + " )\n", + " start_time = time.time()" + ] + }, + { + "cell_type": "markdown", + "id": "e87e944d", + "metadata": {}, + "source": [ + "## Sampling from the trained multimodal model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "e2aab2d8", + "metadata": {}, + "outputs": [], + "source": [ + "x_init = [\n", + " torch.randint_like(\n", + " x1_disc, high=vocab_size\n", + " ), # discrete initial state (uniform categorical)\n", + " torch.randn_like(x1_cont), # continuous initial state (Gaussian noise)\n", + "]\n", + "\n", + "flow.eval() # switch to eval mode for sampling\n", + "samples = flow.sample(x_init=x_init, device=device, steps=1000)" + ] + }, + { + "cell_type": "markdown", + "id": "2bceb4bb", + "metadata": {}, + "source": [ + "## Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "43dc2909", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# ---- Discrete modality -------------------------------------------------\n", + "discrete_samples = samples[0].cpu().numpy() # shape (N, 2) integer tokens\n", + "vocab = vocab_size\n", + "\n", + "# Plot a 2‑D histogram of the discrete samples\n", + "plt.figure(figsize=(6, 5))\n", + "plt.hist2d(\n", + " discrete_samples[:, 0],\n", + " discrete_samples[:, 1],\n", + " bins=vocab,\n", + " cmap=\"viridis\",\n", + ")\n", + "plt.title(\"Discrete modality samples (token histogram)\")\n", + "plt.xlabel(\"Token 1\")\n", + "plt.ylabel(\"Token 2\")\n", + "plt.colorbar(label=\"Count\")\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# ---- Continuous modality -----------------------------------------------\n", + "continuous_samples = samples[1].cpu().numpy() # shape (N, 2)\n", + "\n", + "# Plot a 2‑D histogram of the continuous samples\n", + "plt.figure(figsize=(6, 5))\n", + "plt.hist2d(\n", + " continuous_samples[:, 0],\n", + " continuous_samples[:, 1],\n", + " bins=200,\n", + " cmap=\"viridis\",\n", + ")\n", + "plt.title(\"Continuous modality samples (2-D density)\")\n", + "plt.xlabel(\"x₁\")\n", + "plt.ylabel(\"x₂\")\n", + "plt.colorbar(label=\"Count\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + }, + "kernelspec": { + "display_name": "flow_matching", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/README.md b/examples/README.md index 3b6af05..45542bb 100644 --- a/examples/README.md +++ b/examples/README.md @@ 
-16,6 +16,6 @@ | [standalone_discrete_flow_matching.ipynb](standalone_discrete_flow_matching.ipynb) | A concise discrete flow matching example built in pure PyTorch. | | [2d_flow_matching.ipynb](2d_flow_matching.ipynb) | 2D flow matching example on the checkerboard dataset using the flow_matching library. | | [2d_discrete_flow_matching.ipynb](2d_discrete_flow_matching.ipynb) | 2D discrete flow matching example on the checkerboard dataset using the flow_matching library. | +| [2d_multimodal_flow_matching.ipynb](2d_multimodal_flow_matching.ipynb) | 2D multimodal (discrete-continuous) flow matching example on the checkerboard dataset using the flow_matching library. | | [2d_riemannian_flow_matching_flat_torus.ipynb](2d_riemannian_flow_matching_flat_torus.ipynb) | 2D Riemannian flow matching on a flat torus on the checkerboard dataset and the flow_matching library. | | [2d_riemannian_flow_matching_sphere.ipynb](2d_riemannian_flow_matching_sphere.ipynb) | 2D Riemannian flow matching on a sphere on the checkerboard dataset and the flow_matching library. | - diff --git a/flow_matching/solver/__init__.py b/flow_matching/solver/__init__.py index 6bd7b01..3a3bedb 100644 --- a/flow_matching/solver/__init__.py +++ b/flow_matching/solver/__init__.py @@ -14,5 +14,6 @@ "Solver", "ModelWrapper", "MixtureDiscreteEulerSolver", + "MultimodalSolver", "RiemannianODESolver", ] diff --git a/flow_matching/solver/multimodal_solver.py b/flow_matching/solver/multimodal_solver.py new file mode 100644 index 0000000..029e1a6 --- /dev/null +++ b/flow_matching/solver/multimodal_solver.py @@ -0,0 +1,355 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree.
+ +from contextlib import nullcontext +from math import ceil +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +import torch +from torch import Tensor + +from torch.nn import functional as F + +from flow_matching.path import MixtureDiscreteProbPath +from flow_matching.solver.solver import Solver +from flow_matching.solver.utils import get_nearest_times +from flow_matching.utils import categorical, expand_tensor_like, ModelWrapper + +try: + from tqdm import tqdm + + TQDM_AVAILABLE = True +except ImportError: + TQDM_AVAILABLE = False + + +class MultimodalSolver(Solver): + """Solver for multiple continuous and discrete data modalities. + + This solver handles an arbitrary number of modalities, which can be either + continuous or discrete. Each modality has its own state tensor. + All modalities share the same time discretization and are updated + simultaneously at each step. + + For continuous modalities, an Euler integration step is used. For discrete + modalities, the update follows the procedure from `MixtureDiscreteEulerSolver`. + + Args: + model (Union[ModelWrapper, Callable]): + A model that receives a sequence of state tensors + (one per modality) as ``x`` and a matching sequence of time tensors ``t``, + and returns a sequence of output tensors. For continuous modalities, + the output is a velocity. For discrete modalities, it is the + posterior probability `p_1t`. + modality_configs (List[Dict[str, Any]]): + A list of configuration dictionaries, one for each modality. + Each dictionary must have a ``'type'`` key, which is either + ``'continuous'`` or ``'discrete'``. Discrete modality configs may + provide a ``'dtype_categorical'`` key with the desired data type + for categorical logit sampling (e.g., ``torch.float32``) and + must provide a ``'path'`` key with a `MixtureDiscreteProbPath` + instance.
Continuous modality configs must provide a ``'path'`` + key with a `ProbPath` instance + (e.g., `AffineProbPath(scheduler=CondOTScheduler())`) as well as + an ``'x_1_prediction'`` key which is either ``True`` or ``False``. + If ``True``, the model is expected to predict the clean data `x_1` + for that modality, and such predictions will be reparameterized + as velocities during the sampling process. If ``False``, the model + is expected to predict the velocities directly. + source_distribution_p (Optional[Tensor], optional): Source distribution, + must be of shape [vocabulary_size]. Required only when divergence-free + term for the probability velocity is non-zero. Defaults to None. + model_sampling_fn (str, optional): If ``model`` is a class instance + with multiple methods, this specifies the method to use for + forward passes during sampling. Defaults to ``"forward"``. + + Raises: + TypeError: If ``model`` is not callable or if ``modality_configs`` + is not a list of dictionaries. + """ + + def __init__( + self, + model: Union[ModelWrapper, Callable], + modality_configs: List[Dict[str, Any]], + source_distribution_p: Optional[Tensor] = None, + model_sampling_fn: str = "forward", + ): + super().__init__() + if not callable(model): + raise TypeError(f"model must be callable, got {type(model)}") + self.model = model + self.modality_configs = modality_configs + self.source_distribution_p = source_distribution_p + self.model_sampling_fn = model_sampling_fn + + self._validate_configs() + + def _validate_configs(self): + """Validates the modality configurations.""" + if not isinstance(self.modality_configs, list): + raise TypeError("modality_configs must be a list of dictionaries.") + for i, config in enumerate(self.modality_configs): + if not isinstance(config, dict): + raise TypeError(f"Config for modality {i} must be a dictionary.") + if "type" not in config: + raise ValueError(f"Config for modality {i} must have a 'type' key.") + if config["type"] not in 
["continuous", "discrete"]: + raise ValueError( + f"Unsupported modality type '{config['type']}' for modality {i}." + ) + if config["type"] == "discrete": + if "path" not in config: + raise ValueError( + f"Discrete modality {i} requires a 'path' in its config." + ) + if not isinstance(config["path"], MixtureDiscreteProbPath): + raise TypeError( + f"'path' for discrete modality {i} must be a MixtureDiscreteProbPath instance." + ) + if config["type"] == "continuous": + if "path" not in config: + raise ValueError( + f"Continuous modality {i} requires a 'path' in its config." + ) + if "x_1_prediction" not in config: + raise ValueError( + f"Continuous modality {i} requires an 'x_1_prediction' key in its config." + ) + if not isinstance(config["x_1_prediction"], bool): + raise TypeError( + f"'x_1_prediction' for continuous modality {i} must be a boolean." + ) + + def sample( + self, + x_init: Sequence[Tensor], + step_size: Optional[float], + div_free: Union[float, Callable[[float], float]] = 0.0, + method: str = "euler", + time_grid: Tensor = torch.tensor([0.0, 1.0]), + return_intermediates: bool = False, + enable_grad: bool = False, + verbose: bool = False, + **model_extras: dict, + ) -> Union[Sequence[Tensor], Sequence[List[Tensor]]]: + """Sample all modalities simultaneously. + + Args: + x_init (Sequence[Tensor]): Initial states for each modality. + step_size (Optional[float]): Fixed step size for uniform discretization. + If ``None``, the discretization is taken from ``time_grid``. + div_free (Union[float, Callable[[float], float]]): The coefficient + of the divergence-free term in the probability velocity + (for discrete modalities). Can be either a float or a time + dependent function. Defaults to 0.0. + method (str): Numerical integration method. Currently only ``"euler"`` is + supported, representing a single forward step. + time_grid (Tensor): Tensor of time points defining the interval. 
+ return_intermediates (bool): If ``True``, returns a list of tensors for + each modality containing the state at each intermediate time step. + enable_grad (bool): Whether to enable gradient tracking during integration. + verbose (bool): If ``True``, displays a progress bar during sampling. + **model_extras (dict): Additional arguments passed to the model. + + Raises: + ValueError: If the number of initial states does not match the number of + modality configurations. + NotImplementedError: If an unsupported integration method is specified. + ImportError: If ``verbose`` is ``True`` but ``tqdm`` is not installed. + TypeError: If the model's output does not match the expected format. + + Returns: + Union[Sequence[Tensor], Sequence[List[Tensor]]]: If ``return_intermediates`` is + ``False`` (default), returns a list of final state tensors, one per + modality. If ``True``, returns a list where each element is another + list of tensors representing the trajectory for a modality. + """ + if len(x_init) != len(self.modality_configs): + raise ValueError( + "Number of initial states must match the number of modality configurations." + ) + if method != "euler": + raise NotImplementedError( + f"Method '{method}' is not implemented for MultimodalSolver." + ) + if not div_free == 0.0: + assert ( + self.source_distribution_p is not None + ), "Source distribution p must be specified in order to add a divergence-free term to the probability velocity for each discrete modality." + + # Initialize the current state `x_t` with the initial state `X_0`. + device = x_init[0].device + batch_size = x_init[0].shape[0] + time_grid = time_grid.to(device) + + if step_size is None: + # If step_size is None then set the t discretization to time_grid. + t_discretization = time_grid + n_steps = len(time_grid) - 1 + else: + # If step_size is float then t discretization is uniform with step size set by step_size. 
+ t_init = time_grid[0].item() + t_final = time_grid[-1].item() + assert ( + t_final - t_init + ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}." + + n_steps = ceil((t_final - t_init) / step_size) + t_discretization = torch.tensor( + [t_init + step_size * i for i in range(n_steps)] + [t_final], + device=device, + ) + + if return_intermediates: + # Get order of intermediate steps + order = torch.argsort(time_grid) + # Compute intermediate steps to return via nearest points in t_discretization to time_grid. + time_grid = get_nearest_times( + time_grid=time_grid, t_discretization=t_discretization + ) + + states: Sequence[Tensor] = [(x if enable_grad else x.clone()) for x in x_init] + intermediates: Sequence[List[Tensor]] = ( + [[x if enable_grad else x.clone()] for x in x_init] + if return_intermediates + else [] + ) + + steps_counter = 0 + + if verbose: + if not TQDM_AVAILABLE: + raise ImportError( + "tqdm is required for verbose mode. Please install it." + ) + ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}") + else: + ctx = nullcontext() + + with ctx, torch.set_grad_enabled(enable_grad): + for i in range(n_steps): + # NOTE: For now, all modalities share the same time + t = [t_discretization[i : i + 1].repeat(batch_size)] * len(states) + h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1] + + model_fn = getattr(self.model, self.model_sampling_fn, self.model) + outputs = model_fn(states, t, **model_extras) + + if not isinstance(outputs, (list, tuple)) or len(outputs) != len( + states + ): + raise TypeError( + "The model must return a sequence of tensors matching the number of modalities." 
+ ) + + for idx, config in enumerate(self.modality_configs): + model_output = outputs[idx] + + t_expanded = expand_tensor_like( + input_tensor=t[idx], + expand_to=model_output, + ) + + if config["type"] == "continuous": + # Sample x_{t+h} = x_t + h * v(x_t,t) + path = config["path"] + velocity_output = ( + path.target_to_velocity( + x_1=model_output, x_t=states[idx], t=t_expanded + ) + if config["x_1_prediction"] + else model_output + ) + + states[idx] = states[idx] + h * velocity_output + + elif config["type"] == "discrete": + dtype = config.get("dtype_categorical", torch.float32) + + # Sample x_1 ~ p_1|t( \cdot |x_t) + p_1t = torch.softmax(model_output, dim=-1) + x_1 = categorical(p_1t.to(dtype=dtype)) + + # Checks if final step + if i == n_steps - 1: + states[idx] = x_1 # x_t = x_1 at final step + else: + vocabulary_size = p_1t.shape[-1] + if self.source_distribution_p is not None: + assert self.source_distribution_p.shape == torch.Size( + [vocabulary_size] + ), f"Source distribution p dimension must match the vocabulary size {vocabulary_size}. Got {self.source_distribution_p.shape}." 
+ + # Compute u_t(x|x_t,x_1) + path: MixtureDiscreteProbPath = config["path"] + scheduler_output = path.scheduler(t=t_expanded) + + k_t = scheduler_output.alpha_t + d_k_t = scheduler_output.d_alpha_t + + delta_1 = F.one_hot(x_1, num_classes=vocabulary_size).to( + k_t.dtype + ) + u = d_k_t / (1 - k_t) * delta_1 + + # Add divergence-free part + div_free_t = ( + div_free(t_expanded) if callable(div_free) else div_free + ) + + if div_free_t > 0: + p_0 = self.source_distribution_p[ + (None,) * states[idx].dim() + ] + u = u + div_free_t * d_k_t / (k_t * (1 - k_t)) * ( + (1 - k_t) * p_0 + k_t * delta_1 + ) + + # Set u_t(x_t|x_t,x_1) = 0 + delta_t = F.one_hot( + states[idx], num_classes=vocabulary_size + ) + u = torch.where( + delta_t.to(dtype=torch.bool), torch.zeros_like(u), u + ) + + # Sample x_t ~ u_t( \cdot |x_t,x_1) + intensity = u.sum(dim=-1) # Assuming u_t(xt|xt,x1) := 0 + mask_jump = torch.rand( + size=states[idx].shape, device=states[idx].device + ) < 1 - torch.exp(-h * intensity) + + if mask_jump.sum() > 0: + states[idx][mask_jump] = categorical( + u[mask_jump].to(dtype=dtype) + ) + + # Increment time for each modality + t[idx] = t[idx] + h + + steps_counter += 1 + + if return_intermediates: + for idx, s in enumerate(states): + if t[idx] in time_grid: + intermediates[idx].append(s if enable_grad else s.clone()) + + if verbose: + ctx.n = (torch.cat(t) * n_steps).mean().long().item() + ctx.refresh() + ctx.set_description(f"NFE: {steps_counter}") + + if return_intermediates: + if step_size is None: + return intermediates + else: + return [ + [intermediates[idx][i] for i in order] + for idx in range(len(intermediates)) + ] + else: + return states diff --git a/flow_matching/utils/multimodal.py b/flow_matching/utils/multimodal.py new file mode 100644 index 0000000..f239aa8 --- /dev/null +++ b/flow_matching/utils/multimodal.py @@ -0,0 +1,287 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
MULTIMODAL_METHOD = Literal["euler"]


def _default_continuous_loss(
    pred: Tensor, target: Tensor, reduction: str = "none"
) -> Tensor:
    """
    Squared error loss for continuous modalities.

    Args:
        pred (Tensor): predicted velocity field.
        target (Tensor): target velocity field.
        reduction (str): reduction method, one of 'mean', 'sum', or 'none'.

    Raises:
        ValueError: if reduction is not one of 'none', 'mean', or 'sum'.

    Returns:
        Tensor: computed loss.
    """
    loss = (pred - target) ** 2

    if reduction == "mean":
        return torch.mean(loss)
    if reduction == "sum":
        return torch.sum(loss)
    if reduction == "none":
        return loss
    raise ValueError("reduction must be one of 'none', 'mean', or 'sum'")


class Flow(nn.Module):
    """
    Generic multimodal flow matching model.

    This class aggregates multiple modalities, each with its own path, loss
    and loss weight, around a single shared model. It provides utilities for
    training (computing the total loss) and inference (sampling) across all
    modalities.

    Args:
        model (nn.Module):
            A model that receives a sequence of state tensors
            (one per modality) as ``x`` and a sequence of time tensors ``t``,
            and returns a sequence of output tensors, one per modality.
        modalities (Dict[str, Dict[str, Any]]):
            Mapping from modality name to a dict with keys:
            - "path": A probability path object (e.g., MixtureDiscreteProbPath for discrete data,
              or any continuous path implementation).
            - "loss" (optional): A callable loss function. If omitted, a default loss is chosen
              based on the path type.
            - "weight" (optional): A float weight for the modality's training loss. Defaults to 1.0.
            - "x_1_prediction" (continuous paths only, optional): If True, the model is expected to predict
              the clean data `x_1` for that modality, and such predictions will be reparameterized
              as velocities during the sampling process. If False, the model is expected to predict
              the velocities directly. Defaults to False.
        model_sampling_fn (str, optional): If ``model`` is a class instance
            with multiple methods, this specifies the method to use for
            forward passes during sampling. Defaults to ``"forward"``.
    """

    def __init__(
        self,
        model: nn.Module,
        modalities: Dict[str, Dict[str, Any]],
        model_sampling_fn: str = "forward",
    ) -> None:
        super().__init__()
        self.model = model
        self.paths: Dict[str, Any] = {}
        self.loss_fns: Dict[str, Callable] = {}
        self.loss_weights: Dict[str, float] = {}
        # One solver config per modality, in the same order as `self.paths`.
        self.modality_configs: List[Dict[str, Any]] = []

        for name, spec in modalities.items():
            path = spec["path"]
            is_discrete = isinstance(path, MixtureDiscreteProbPath)

            # Choose loss function: an explicit one wins, otherwise pick the
            # default matching the path type.
            loss_fn = spec.get("loss")
            if loss_fn is None:
                loss_fn = (
                    MixturePathGeneralizedKL(path, reduction="none")
                    if is_discrete
                    else _default_continuous_loss
                )

            self.paths[name] = path
            self.loss_fns[name] = loss_fn
            self.loss_weights[name] = spec.get("weight", 1.0)
            self.modality_configs.append(
                {
                    "name": name,
                    "type": "discrete" if is_discrete else "continuous",
                    "path": path,
                    "x_1_prediction": spec.get("x_1_prediction", False),
                }
            )

        # A single Euler solver handles all modalities jointly.
        self.solver = MultimodalSolver(
            model=self.model,
            modality_configs=self.modality_configs,
            model_sampling_fn=model_sampling_fn,
        )

    def training_loss(
        self,
        x_1: Sequence[Tensor],
        x_t: Sequence[Tensor],
        dx_t: Sequence[Tensor],
        t: Sequence[Tensor],
        model_output: Optional[Sequence[Tensor]] = None,
        detach_loss_dict: bool = True,
        **model_extras: dict,
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        """
        Compute the total training loss across all modalities.

        Args:
            x_1 (Sequence[Tensor]): Sequence of tensors, one per modality,
                containing the data at time 1.
            x_t (Sequence[Tensor]): Sequence of tensors, one per modality,
                containing the data at time t.
            dx_t (Sequence[Tensor]): Sequence of tensors, one per modality,
                containing the velocity field at time t. Entries of discrete
                modalities are not read.
            t (Sequence[Tensor]): Sequence of tensors, one per modality,
                containing the time values.
            model_output (Optional[Sequence[Tensor]]): Optional precomputed model outputs.
                If provided, these are used instead of calling the model.
            detach_loss_dict (bool): If ``True``, detaches individual modality losses
                from the computation graph when storing them in the loss dictionary.
                Defaults to ``True``.
            **model_extras (dict): Additional keyword arguments to pass to the model.

        Raises:
            AssertionError: If the input sequences do not match the number of
                modalities, or a state tensor has the wrong dtype for its modality.

        Returns:
            Tuple[Tensor, Dict[str, Tensor]]:
                Scalar loss (weighted sum of the per-modality mean losses) and
                a dictionary of the weighted, unreduced per-modality losses.
        """
        assert (
            len(x_1) == len(x_t) == len(dx_t) == len(t) == len(self.paths)
        ), "Input sequences must match the number of modalities."

        if model_output is None:
            # Explicit `is None`: truthiness is ambiguous for tensor-like outputs.
            model_output = self.model(x_t, t, **model_extras)
        else:
            assert len(model_output) == len(
                self.paths
            ), "If provided, model outputs must match the number of modalities."

        loss_dict = {}
        total_loss = 0.0

        for i, config in enumerate(self.modality_configs):
            name = config["name"]
            loss_fn = self.loss_fns[name]

            if config["type"] == "discrete":
                # Discrete case: model should output logits.
                assert x_t[i].dtype == torch.long, (
                    f"Expected integer tensor for discrete modality '{name}', "
                    f"got {x_t[i].dtype}"
                )
                loss = loss_fn(model_output[i], x_1[i], x_t[i], t[i])
            else:
                # Continuous case: the regression target is `x_1` when the
                # model predicts clean data, the velocity otherwise.
                assert x_t[i].is_floating_point(), (
                    f"Expected float tensor for continuous modality '{name}', "
                    f"got {x_t[i].dtype}"
                )
                target = x_1[i] if config["x_1_prediction"] else dx_t[i]
                loss = loss_fn(model_output[i], target)

            weight = self.loss_weights[name]
            loss_dict[name] = (loss.detach() if detach_loss_dict else loss) * weight
            total_loss = total_loss + loss.mean() * weight

        return total_loss, loss_dict

    def sample(
        self,
        x_init: Sequence[Tensor],
        time_grid: Optional[Tensor] = None,
        device: torch.device = torch.device("cpu"),
        steps: int = 1000,
        step_size: Optional[float] = None,
        div_free: Union[float, Callable[[float], float]] = 0.0,
        method: MULTIMODAL_METHOD = "euler",
        return_intermediates: bool = False,
        enable_grad: bool = False,
        verbose: bool = False,
        **model_extras: dict,
    ) -> Union[Sequence[Tensor], Sequence[List[Tensor]]]:
        """
        Generate samples for each modality using the multimodal solver.

        The sequence passed as ``x_init`` is not modified; the tensors are
        moved to ``device`` in a private copy of the sequence.

        Args:
            x_init (Sequence[Tensor]):
                Sequence of tensors, one per modality, containing the initial states at time 0.
                For continuous modalities, this is typically Gaussian noise.
                For discrete modalities, this is typically samples from a uniform categorical distribution.
            time_grid (Optional[Tensor]): Optional tensor of time points. It is
                forwarded to the solver together with a step size that is
                always set, so it does not replace the uniform discretization.
                Defaults to ``torch.linspace(0.0, 1.0, steps)``.
            device (torch.device, optional): Device on which to run the sampling.
            steps (int, optional): Number of integration steps for the solver;
                only used to derive ``step_size`` and the default ``time_grid``.
            step_size (Optional[float]): Fixed step size for uniform discretization.
                If ``None`` (or zero), the step size is ``1.0 / steps``.
            div_free (Union[float, Callable[[float], float]]): The coefficient
                of the divergence-free term in the probability velocity
                (for discrete modalities). Can be either a float or a time
                dependent function. Defaults to 0.0.
            method (MULTIMODAL_METHOD): Numerical integration method. Currently only ``"euler"`` is
                supported, representing a single forward step.
            return_intermediates (bool): If ``True``, returns a list of tensors for
                each modality containing the state at each intermediate time step.
            enable_grad (bool): Whether to enable gradient tracking during integration.
            verbose (bool): If ``True``, displays progress during sampling.
            **model_extras (dict): Additional keyword arguments to pass to the model.

        Raises:
            AssertionError: If an initial state has the wrong dtype for its modality.

        Returns:
            Union[Sequence[Tensor], Sequence[List[Tensor]]]: A list where each element corresponds to a modality.
                Each element is either a tensor containing the samples,
                or a list of tensors (if `return_intermediates` is True).
        """
        # Always work on a copy so the caller's list keeps its own tensors.
        x_init = list(x_init)

        # Validate samples for each modality.
        for i, config in enumerate(self.modality_configs):
            name = config["name"]
            if config["type"] == "discrete":
                assert x_init[i].dtype == torch.long, (
                    f"Expected integer tensor for discrete modality '{name}', "
                    f"got {x_init[i].dtype}"
                )
            else:
                assert x_init[i].is_floating_point(), (
                    f"Expected float tensor for continuous modality '{name}', "
                    f"got {x_init[i].dtype}"
                )

            x_init[i] = x_init[i].to(device)

        # Solve to obtain multimodal samples at time 1.
        step_size = step_size or (1.0 / steps)
        if time_grid is None:
            time_grid = torch.linspace(0.0, 1.0, steps, device=device)

        return self.solver.sample(
            x_init=x_init,
            step_size=step_size,
            div_free=div_free,
            method=method,
            time_grid=time_grid,
            return_intermediates=return_intermediates,
            enable_grad=enable_grad,
            verbose=verbose,
            **model_extras,
        )
# ----------------------------------------------------------------------
# Stub models used by the solver tests
# ----------------------------------------------------------------------
class ContinuousVelocityModel(ModelWrapper):
    """Single-modality stub whose output is twice the incoming state."""

    def __init__(self):
        super().__init__(None)

    def forward(self, xs: list[Tensor], t: list[Tensor], **extras) -> list[Tensor]:
        # Only the first (and only) modality state is read.
        state = xs[0]
        return [2.0 * state]


class DiscreteLogitsModel(ModelWrapper):
    """Single-modality stub whose logits put all mass on the last class."""

    def __init__(self, vocab_size: int):
        super().__init__(None)
        self.vocab_size = vocab_size

    def forward(self, xs: list[Tensor], t: list[Tensor], **extras) -> list[Tensor]:
        """Produce logits that give probability 1.0 to the last class."""
        state = xs[0]
        logits = torch.full(
            (state.shape[0], self.vocab_size), -1e9, device=state.device
        )
        logits[:, -1] = 1e9
        return [logits]


# ----------------------------------------------------------------------
# Test suite
# ----------------------------------------------------------------------
class TestMultimodalSolver(unittest.TestCase):
    def setUp(self):
        # Continuous modality: the model output is used directly as velocity.
        self.continuous_cfg = {
            "type": "continuous",
            "path": AffineProbPath(scheduler=CondOTScheduler()),
            "x_1_prediction": False,
        }

        # Discrete modality over a three-token vocabulary.
        self.vocab_size = 3
        self.discrete_path = MixtureDiscreteProbPath(
            scheduler=PolynomialConvexScheduler(n=2.0)
        )
        self.discrete_cfg = {
            "type": "discrete",
            "path": self.discrete_path,
        }

        # Uniform source distribution for the divergence-free term.
        self.source_p = torch.tensor([1.0 / self.vocab_size] * self.vocab_size)

        self.continuous_model = ContinuousVelocityModel()
        self.discrete_model = DiscreteLogitsModel(vocab_size=self.vocab_size)

        class CombinedModel(ModelWrapper):
            """Routes modality 0 to the continuous stub and 1 to the discrete one."""

            def __init__(self, cont, disc):
                super().__init__(None)
                self.cont = cont
                self.disc = disc

            def forward(self, xs, t, **extras):
                cont_state, disc_state = xs[0], xs[1]
                (cont_out,) = self.cont.forward([cont_state], t, **extras)[:1]
                (disc_out,) = self.disc.forward([disc_state], t, **extras)[:1]
                return [cont_out, disc_out]

        self.model = CombinedModel(self.continuous_model, self.discrete_model)

    def _two_modality_solver(self, model=None):
        """Build a continuous + discrete solver with the uniform source distribution."""
        return MultimodalSolver(
            model=self.model if model is None else model,
            modality_configs=[self.continuous_cfg, self.discrete_cfg],
            source_distribution_p=self.source_p,
        )

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------
    def test_init(self):
        solver = self._two_modality_solver()
        expected_configs = [self.continuous_cfg, self.discrete_cfg]
        self.assertIs(solver.model, self.model)
        self.assertEqual(solver.modality_configs, expected_configs)
        self.assertTrue(torch.allclose(solver.source_distribution_p, self.source_p))

    # ------------------------------------------------------------------
    # Plain sampling (continuous + discrete)
    # ------------------------------------------------------------------
    def test_sample_basic(self):
        solver = self._two_modality_solver()
        # Both modalities use batch size 1.
        start = [torch.tensor([[0.0]]), torch.tensor([[0]])]
        result = solver.sample(
            x_init=start,
            step_size=0.1,
            time_grid=torch.tensor([0.0, 1.0]),
        )
        # v = 2 * x keeps a zero state at zero under Euler steps.
        self.assertTrue(torch.allclose(result[0], torch.zeros_like(result[0])))
        # The logits always pick the last class.
        last_class = torch.tensor([self.vocab_size - 1])
        self.assertTrue(torch.equal(result[1], last_class))

    # ------------------------------------------------------------------
    # Trajectories
    # ------------------------------------------------------------------
    def test_return_intermediates(self):
        solver = self._two_modality_solver()
        start = [torch.tensor([[1.0]]), torch.tensor([[0]])]
        intermediates = solver.sample(
            x_init=start,
            step_size=0.5,
            time_grid=torch.tensor([0.0, 0.5, 1.0]),
            return_intermediates=True,
        )
        # One trajectory per modality, each with start, middle and end.
        self.assertEqual(len(intermediates), 2)
        self.assertEqual(len(intermediates[0]), 3)
        self.assertEqual(len(intermediates[1]), 3)
        # The discrete trajectory ends on the last class.
        last_class = torch.tensor([self.vocab_size - 1])
        self.assertTrue(torch.equal(intermediates[1][-1], last_class))

    # ------------------------------------------------------------------
    # Gradient tracking
    # ------------------------------------------------------------------
    def test_gradient_enabled(self):
        solver = self._two_modality_solver()
        x_cont = torch.tensor([[2.0]], requires_grad=True)
        x_disc = torch.tensor([[0]], requires_grad=False)
        result = solver.sample(
            x_init=[x_cont, x_disc],
            step_size=0.1,
            time_grid=torch.tensor([0.0, 1.0]),
            enable_grad=True,
        )
        # Backpropagate through the continuous sample only.
        result[0].sum().backward()
        self.assertIsNotNone(x_cont.grad)
        self.assertIsNone(x_disc.grad)

    # ------------------------------------------------------------------
    # Non-zero divergence-free term
    # ------------------------------------------------------------------
    def test_divergence_free(self):
        # The mocked forward returns zeros for both modalities.
        mock_model = MagicMock()
        mock_model.forward = MagicMock()
        mock_model.forward.return_value = [
            torch.zeros(1, 1),
            torch.zeros(1, 1, self.vocab_size),
        ]

        solver = self._two_modality_solver(model=mock_model)
        result = solver.sample(
            x_init=[torch.tensor([[0.0]]), torch.tensor([[0]])],
            step_size=0.1,
            div_free=0.5,
            time_grid=torch.tensor([0.0, 1.0]),
        )
        # Only the successful call and the output structure are checked.
        self.assertIsInstance(result, list)
        self.assertEqual(len(result), 2)

    # ------------------------------------------------------------------
    # Error handling
    # ------------------------------------------------------------------
    def test_mismatched_initial_states(self):
        solver = MultimodalSolver(
            model=self.model,
            modality_configs=[self.continuous_cfg, self.discrete_cfg],
        )
        # Two modalities are configured but only one state is supplied.
        only_one_state = [torch.tensor([[0.0]])]
        with self.assertRaises(ValueError):
            solver.sample(
                x_init=only_one_state,
                step_size=0.1,
                time_grid=torch.tensor([0.0, 1.0]),
            )

    def test_invalid_modality_type(self):
        with self.assertRaises(ValueError):
            MultimodalSolver(
                model=self.model,
                modality_configs=[{"type": "unknown"}],
            )

    def test_missing_path_for_discrete(self):
        # The discrete config lacks its 'path' key.
        with self.assertRaises(ValueError):
            MultimodalSolver(
                model=self.model,
                modality_configs=[{"type": "discrete"}],
            )

    def test_non_callable_model(self):
        with self.assertRaises(TypeError):
            MultimodalSolver(
                model=123,  # Not callable
                modality_configs=[self.continuous_cfg],
            )


if __name__ == "__main__":
    unittest.main()
class DummyModel(nn.Module):
    """Model that returns logits for discrete and scaled inputs for continuous."""

    def __init__(self, num_classes: int = 5):
        super().__init__()
        self.num_classes = num_classes

    def _respond(self, x):
        # Integer states are treated as discrete: answer with random logits.
        if x.dtype == torch.long:
            return torch.randn(x.shape[0], self.num_classes)
        # Any other state is treated as continuous: answer with a scaled copy.
        return x * 2.0

    def forward(self, xs, t, **kwargs):
        return [self._respond(x) for x in xs]


class DummyMultimodalSolver:
    """Mock solver that records arguments and returns predefined samples."""

    def __init__(self, model, modality_configs, model_sampling_fn=None):
        self.model = model
        self.modality_configs = modality_configs
        self.model_sampling_fn = model_sampling_fn
        self.called_with = {}

    def sample(self, **kwargs):
        # Keep the keyword arguments and answer with one fixed tensor per modality.
        self.called_with = kwargs
        return [torch.tensor([1]), torch.tensor([2.0])]


class TestFlow(unittest.TestCase):
    def setUp(self):
        self.num_classes = 5
        self.discrete_path = MixtureDiscreteProbPath(
            scheduler=PolynomialConvexScheduler(n=2.0)
        )
        self.continuous_path = AffineProbPath(scheduler=CondOTScheduler())
        self.modalities = {
            "disc": {"path": self.discrete_path},
            "cont": {"path": self.continuous_path},
        }
        self.model = DummyModel(num_classes=self.num_classes)
        self.flow = Flow(model=self.model, modalities=self.modalities)

    def _random_inputs(self, batch):
        """Draw ``(x_1, x_t, dx_t, t)`` for the (disc, cont) modality order."""
        # Discrete tensors (int64) first, then continuous tensors (float32).
        x1_disc = torch.randint(0, self.num_classes, (batch,))
        x_t_disc = torch.randint(0, self.num_classes, (batch,))
        x1_cont = torch.randn(batch, 2)
        x_t_cont = torch.randn(batch, 2)
        dx_t_cont = torch.randn(batch, 2)
        t = [torch.rand(batch), torch.rand(batch)]
        return [x1_disc, x1_cont], [x_t_disc, x_t_cont], [None, dx_t_cont], t

    def test_init_paths_and_losses(self):
        # Both paths are stored under their modality names.
        for name, path in (
            ("disc", self.discrete_path),
            ("cont", self.continuous_path),
        ):
            self.assertIn(name, self.flow.paths)
        self.assertIs(self.flow.paths["disc"], self.discrete_path)
        self.assertIs(self.flow.paths["cont"], self.continuous_path)

        # The discrete default loss is some callable; the continuous default
        # is the module's squared-error helper.
        self.assertTrue(callable(self.flow.loss_fns["disc"]))
        self.assertIs(self.flow.loss_fns["cont"], _default_continuous_loss)

    def test_training_loss_computation(self):
        x_1, x_t, dx_t, t = self._random_inputs(batch=3)

        total_loss, loss_dict = self.flow.training_loss(x_1, x_t, dx_t, t)

        # The total is a zero-dimensional tensor.
        self.assertIsInstance(total_loss, torch.Tensor)
        self.assertEqual(total_loss.dim(), 0)

        # One tensor entry per modality.
        self.assertSetEqual(set(loss_dict.keys()), {"disc", "cont"})
        for modality_loss in loss_dict.values():
            self.assertIsInstance(modality_loss, torch.Tensor)
            self.assertEqual(modality_loss.mean().dim(), 0)

        # The total equals the sum of the per-modality means.
        summed = sum(modality_loss.mean() for modality_loss in loss_dict.values())
        self.assertTrue(torch.allclose(total_loss, summed))

    def test_training_loss_mismatched_lengths(self):
        batch = 2
        x1_disc = torch.randint(0, self.num_classes, (batch,))
        x_t_disc = torch.randint(0, self.num_classes, (batch,))

        # Only the discrete modality is supplied, so the length check fails.
        with self.assertRaises(AssertionError):
            self.flow.training_loss(
                [x1_disc], [x_t_disc], [None], [torch.rand(batch)]
            )

    def test_sample_dtype_validation_and_output(self):
        batch = 4
        # Correct dtypes for (disc, cont).
        x_init_disc = torch.randint(0, self.num_classes, (batch,))
        x_init_cont = torch.randn(batch, 2)

        with patch(
            "flow_matching.utils.multimodal.MultimodalSolver",
            DummyMultimodalSolver,
        ):
            # Rebuild the flow so it picks up the dummy solver.
            self.flow = Flow(model=self.model, modalities=self.modalities)
            samples = self.flow.sample([x_init_disc, x_init_cont], steps=5)

        # The dummy solver's fixed output comes back unchanged.
        self.assertEqual(len(samples), 2)
        self.assertTrue(torch.equal(samples[0], torch.tensor([1])))
        self.assertTrue(torch.equal(samples[1], torch.tensor([2.0])))

    def test_sample_wrong_dtype_raises(self):
        batch = 3
        # Float state for the discrete modality (long is required).
        x_init_disc = torch.randn(batch, dtype=torch.float32)
        x_init_cont = torch.randn(batch, 2)

        with self.assertRaises(AssertionError):
            self.flow.sample([x_init_disc, x_init_cont], steps=5)

    def test_custom_loss_weights(self):
        weighted_modalities = {
            "disc": {"path": self.discrete_path, "weight": 0.5},
            "cont": {"path": self.continuous_path, "weight": 2.0},
        }
        flow = Flow(model=self.model, modalities=weighted_modalities)

        x_1, x_t, dx_t, t = self._random_inputs(batch=3)
        total_loss, loss_dict = flow.training_loss(x_1, x_t, dx_t, t)

        # The dictionary entries already carry their weights.
        expected_total = loss_dict["disc"].mean() + loss_dict["cont"].mean()
        self.assertTrue(torch.allclose(total_loss, expected_total))

        # The weights are stored as given.
        self.assertEqual(flow.loss_weights["disc"], 0.5)
        self.assertEqual(flow.loss_weights["cont"], 2.0)

    def test_training_loss_x1_prediction_true(self):
        def custom_continuous_loss(pred, target, reduction="none"):
            # Echo the target so the test can see which tensor was used.
            return target

        modalities = {
            "disc": {"path": self.discrete_path},
            "cont": {
                "path": self.continuous_path,
                "loss": custom_continuous_loss,
                "x_1_prediction": True,
            },
        }
        flow = Flow(model=self.model, modalities=modalities)

        # The velocity entry is present but must not be used as the target.
        x_1, x_t, dx_t, t = self._random_inputs(batch=3)
        x1_cont = x_1[1]

        total_loss, loss_dict = flow.training_loss(x_1, x_t, dx_t, t)

        # The continuous loss echoes x1_cont, so that was the target.
        self.assertTrue(torch.allclose(loss_dict["cont"], x1_cont))
        expected_total = loss_dict["disc"].mean() + loss_dict["cont"].mean()
        self.assertTrue(torch.allclose(total_loss, expected_total))

    def test_training_loss_with_logits_argument(self):
        batch = 3
        x_1, x_t, dx_t, t = self._random_inputs(batch)

        # Fixed outputs for the discrete and the continuous modality.
        logits = [
            torch.full((batch, self.num_classes), 0.5),
            torch.full_like(dx_t[1], 0.1),
        ]

        # The model must not be called when outputs are supplied.
        with patch.object(
            self.flow.model,
            "forward",
            side_effect=AssertionError("Model forward should not be called"),
        ):
            total_loss, loss_dict = self.flow.training_loss(
                x_1, x_t, dx_t, t, model_output=logits
            )

        self.assertIsInstance(total_loss, torch.Tensor)
        self.assertEqual(total_loss.dim(), 0)
        self.assertSetEqual(set(loss_dict.keys()), {"disc", "cont"})
        summed = sum(modality_loss.mean() for modality_loss in loss_dict.values())
        self.assertTrue(torch.allclose(total_loss, summed))


if __name__ == "__main__":
    unittest.main()