Skip to content

Backprop RGB image losses to mesh shape? #839

@priyasundaresan

Description

@priyasundaresan

Backprop RGB image losses to mesh shape?

I am trying to fit a source mesh (sphere) to a target mesh (cow) per the example from https://pytorch3d.org/tutorials/fit_textured_mesh. In particular, I would like to propagate losses taken over the rendered RGB images of the current and target mesh to the vertex positions of the current mesh being deformed.

I am able to achieve the desired results using a 50-50 weighting of L1 silhouette loss and L1 RGB loss taken over the rendered images, and the training progression across 200 iterations is shown here:
00000_depth
00050_depth
00100_depth
00150_depth
00200_depth

However, using only L1 RGB loss, the mesh doesn't converge to the desired shape as shown here:
00000_depth
00050_depth
00100_depth
00150_depth

I have tried using L2 RGB loss and changing the texture color, but still have this issue. Is it possible to use RGB image supervision without silhouettes and propagate to mesh shape? I have referred to the similar issue here but have confirmed that the rasterization settings do not have this problem.

This can be reproduced by running the following code and replacing `losses = {"rgb": {"weight": 0.5, "values": []}, "silhouette": {"weight": 0.5, "values": []}}` with `losses = {"rgb": {"weight": 1.0, "values": []}, "silhouette": {"weight": 0.0, "values": []}}`:

import os
import cv2
import sys
sys.path.append(os.path.abspath(''))
import torch
import os
import torch
import matplotlib.pyplot as plt
from pytorch3d.utils import ico_sphere
import numpy as np
from pytorch3d.io import load_objs_as_meshes, save_obj
from plot_image_grid import image_grid
from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)
# Data structures and functions for rendering
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    OpenGLPerspectiveCameras, 
    PointLights, 
    DirectionalLights, 
    Materials, 
    RasterizationSettings, 
    MeshRenderer, 
    MeshRasterizer,  
    SoftPhongShader,
    SoftSilhouetteShader,
    SoftPhongShader,
    TexturesVertex
)

# Use the first CUDA device when one is available; otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Make the selected GPU the current CUDA device.
    torch.cuda.set_device(device)

# Load the target (cow) mesh onto the selected device.
DATA_DIR = "./data"
obj_filename = os.path.join(DATA_DIR, "cow_mesh/cow.obj")
mesh = load_objs_as_meshes([obj_filename], device=device)

# Replace the loaded textures with a per-vertex feature of all ones
# (one feature row per packed vertex, with a leading batch dimension).
white_tex = TexturesVertex(
    verts_features=torch.ones_like(mesh.verts_packed()).unsqueeze(0).to(device)
)
mesh.textures = white_tex

# Normalise the target in place: subtract the vertex mean, then divide by the
# largest absolute coordinate offset from that mean. `center` and `scale` are
# kept so the fitted mesh can be mapped back to the original frame later.
verts = mesh.verts_packed()
N = verts.shape[0]
center = verts.mean(0)
scale = (verts - center).abs().max()
mesh.offset_verts_(-center)
mesh.scale_verts_(1.0 / float(scale))

# Number of viewpoints rendered around the target mesh.
# Alternative left by the author: num_views = 2 with
# azim = torch.linspace(-180, -90, num_views).
num_views = 20

# Evenly spaced azimuth and elevation angles, one pair per view.
azim = torch.linspace(-180, 180, num_views)
elev = torch.linspace(0, 360, num_views)

# One rotation/translation per view at a fixed distance of 2.7, wrapped in a
# batched camera object; `camera` is the first view, used to build renderers.
R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
camera = cameras[0]

# A single point light shared by every render in this script.
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

# Rasterization settings for the target images: 128x128 output, no blur,
# and a single face kept per pixel.
raster_settings = RasterizationSettings(
    image_size=128,
    blur_radius=0.0,
    faces_per_pixel=1,
    perspective_correct=False,
)

# Phong-shaded RGB renderer built from the settings above. It is constructed
# with the first camera; callers may pass `cameras=`/`lights=` per call.
_rgb_rasterizer = MeshRasterizer(cameras=camera, raster_settings=raster_settings)
_rgb_shader = SoftPhongShader(device=device, cameras=camera, lights=lights)
renderer = MeshRenderer(rasterizer=_rgb_rasterizer, shader=_rgb_shader)

# Rasterization settings for the silhouette pass: a small non-zero blur radius
# derived from sigma, and faces_per_pixel=50.
sigma = 1e-4
raster_settings_silhouette = RasterizationSettings(
    image_size=128,
    blur_radius=np.log(1. / 1e-4 - 1.) * sigma,
    faces_per_pixel=50,
    perspective_correct=False,
)

# Silhouette renderer: the training loop reads channel 3 of its output.
_silhouette_rasterizer = MeshRasterizer(
    cameras=camera,
    raster_settings=raster_settings_silhouette,
)
renderer_silhouette = MeshRenderer(
    rasterizer=_silhouette_rasterizer,
    shader=SoftSilhouetteShader(),
)

# Replicate the target mesh once per view and render all views in one batch.
meshes = mesh.extend(num_views)
target_images = renderer(meshes, cameras=cameras, lights=lights)

# Per-view RGB slices (first three channels) of the batched render.
target_rgb = [target_images[i, ..., :3] for i in range(num_views)]

# One single-view camera per viewpoint, so the training loop can render the
# views individually.
target_cameras = [
    OpenGLPerspectiveCameras(device=device, R=R[None, i, ...], T=T[None, i, ...])
    for i in range(num_views)
]

# Lay the target renders out on a 4x5 grid (matches num_views = 20).
image_grid(target_images.cpu().numpy(), rows=4, cols=5, rgb=True)
# Save before showing: with an interactive backend the figure may be closed
# once plt.show() returns, which would make a later savefig write a blank image.
plt.savefig('cows_rgb.png')
plt.show()
plt.clf()

# Show a visualization comparing the rendered predicted mesh to the ground truth
def visualize_pred(curr_image, ref_image, fname):
    """Write the predicted and reference renders side by side to ``fname``.

    Args:
        curr_image: tensor holding the current render; it is detached and
            moved to the CPU before being converted to a NumPy array.
        ref_image: tensor holding the reference render, handled the same way.
            Both are assumed to be (H, W, 3) RGB floats in [0, 1], as sliced
            from the renderer output by the caller - TODO confirm.
        fname: output image path handed to ``cv2.imwrite``.
    """
    side_by_side = np.hstack(
        (curr_image.detach().cpu().numpy(), ref_image.detach().cpu().numpy())
    )
    # cv2.imwrite is specified for 8-bit images, so clamp and convert
    # explicitly instead of handing it a float array.
    pixels = np.clip(side_by_side * 255.0, 0.0, 255.0).round().astype(np.uint8)
    if pixels.ndim == 3 and pixels.shape[-1] == 3:
        # OpenCV stores colour images as BGR; reverse the channel axis so the
        # saved file keeps the colours of the (assumed RGB) render.
        pixels = np.ascontiguousarray(pixels[..., ::-1])
    cv2.imwrite(str(fname), pixels)

# Plot losses as a function of optimization iteration
def plot_losses(losses):
    """Draw one curve per loss term on a new 13x5 figure.

    Args:
        losses: mapping from loss name to a dict whose ``'values'`` entry is
            the sequence of recorded values; each curve is labelled
            ``"<name> loss"``.

    The figure is created but neither shown nor saved by this function.
    """
    text_size = "16"
    figure = plt.figure(figsize=(13, 5))
    axes = figure.gca()
    for name, record in losses.items():
        axes.plot(record['values'], label=name + " loss")
    axes.legend(fontsize=text_size)
    axes.set_xlabel("Iteration", fontsize=text_size)
    axes.set_ylabel("Loss", fontsize=text_size)
    axes.set_title("Loss vs iterations", fontsize=text_size)

# The shape being optimised starts out as an icosphere (level 4), i.e. a
# sphere of radius 1.
src_mesh = ico_sphere(4, device)

# Give it the same all-ones per-vertex texture as the target mesh.
white_tex = TexturesVertex(
    verts_features=torch.ones_like(src_mesh.verts_packed()).unsqueeze(0).to(device)
)
src_mesh.textures = white_tex


# Soft rasterization settings for differentiable rendering. The blur_radius
# initialisation follows Liu et al., 'Soft Rasterizer: A Differentiable
# Renderer for Image-based 3D Reasoning', ICCV 2019.
sigma = 1e-4
raster_settings_soft = RasterizationSettings(
    image_size=128,
    blur_radius=np.log(1. / 1e-4 - 1.) * sigma,
    faces_per_pixel=50,
    perspective_correct=False,
)

# Phong-shaded RGB renderer using the soft settings above. This rebinds
# `renderer`, so every later render call uses these soft settings rather than
# the blur-free ones used for the target image grid.
_soft_rasterizer = MeshRasterizer(cameras=camera, raster_settings=raster_settings_soft)
_soft_shader = SoftPhongShader(device=device, cameras=camera, lights=lights)
renderer = MeshRenderer(rasterizer=_soft_rasterizer, shader=_soft_shader)

# Optimisation schedule: views sampled per step (the author notes 5 works
# best), total number of steps, and how often a comparison image is written.
num_views_per_iteration = 5
Niter = 350
plot_period = 50

# Per-term loss bookkeeping: a fixed weight and the history of values, with a
# separate history list for each term.
losses = {
    term: {"weight": 0.5, "values": []}
    for term in ("rgb", "silhouette")
}
# Learnable per-vertex offsets, initialised to zero, with the same shape as
# the packed vertices of the source mesh.
verts_shape = src_mesh.verts_packed().shape
deform_verts = torch.zeros(verts_shape, device=device, requires_grad=True)

# Adam updates only the offsets; the source mesh itself stays fixed.
optimizer = torch.optim.Adam([deform_verts], lr=0.01)

# Iteration indices for the training loop.
loop = range(Niter)

# Directory that receives the periodic comparison images.
out_dir = os.path.join('out')
# exist_ok avoids the check-then-create race of testing os.path.exists first.
os.makedirs(out_dir, exist_ok=True)

# Reference renders depend only on the fixed target mesh, camera and lights,
# so each view is rendered at most once and reused across iterations.
ref_cache = {}

for i in loop:
    # Clear the gradients left over from the previous step.
    optimizer.zero_grad()
    # Apply the current per-vertex offsets to the source sphere.
    new_src_mesh = src_mesh.offset_verts(deform_verts)
    # Per-term loss accumulators for this iteration.
    loss = {k: torch.tensor(0.0, device=device) for k in losses}

    # Sample a random subset of the views, without replacement.
    view_ids = np.random.permutation(num_views).tolist()[:num_views_per_iteration]
    for j in view_ids:
        view_camera = target_cameras[j]
        curr_image = renderer(new_src_mesh, cameras=view_camera, lights=lights)[..., :3]
        curr_sil = renderer_silhouette(new_src_mesh, cameras=view_camera, lights=lights)[..., 3]

        if j not in ref_cache:
            # The target never changes, so no autograd graph is needed for it.
            with torch.no_grad():
                ref_cache[j] = (
                    renderer(meshes[j], cameras=view_camera, lights=lights)[..., :3],
                    renderer_silhouette(meshes[j], cameras=view_camera, lights=lights)[..., 3],
                )
        ref_image, ref_sil = ref_cache[j]

        loss_rgb = (torch.abs(curr_image - ref_image)).mean()  # L1 rgb loss
        loss_silhouette = (torch.abs(curr_sil - ref_sil)).mean()  # L1 silhouette loss
        # Average each term over the sampled views.
        loss["silhouette"] += loss_silhouette / num_views_per_iteration
        loss["rgb"] += loss_rgb / num_views_per_iteration

    # Weighted sum of the losses; the unweighted value of each term is logged.
    sum_loss = torch.tensor(0.0, device=device)
    for k, l in loss.items():
        sum_loss += l * losses[k]["weight"]
        losses[k]["values"].append(float(l.detach().cpu()))

    # Print the losses on a single, continuously overwritten line.
    print("iter: %d/%d, total_loss = %.6f" % (i,Niter,sum_loss), end='\r')
    sys.stdout.flush()

    # Periodically save the last sampled view next to its reference render.
    if i % plot_period == 0:
        visualize_pred(curr_image[0], ref_image[0], fname="%s/%05d_depth.png"%(out_dir, i))

    # Optimization step
    sum_loss.backward()
    optimizer.step()

# `new_src_mesh` is built before the last optimizer.step(), so it lags the
# final offsets by one update. Rebuild the mesh from the final offsets
# (without tracking gradients) so the exported shape is the optimised one.
with torch.no_grad():
    final_mesh = src_mesh.offset_verts(deform_verts)
final_verts, final_faces = final_mesh.get_mesh_verts_faces(0)

# Undo the normalisation applied to the target so the result is saved in the
# original coordinate frame of the cow mesh.
final_verts = final_verts * scale + center
final_obj = os.path.join('./', 'cow_sil_model.obj')
save_obj(final_obj, final_verts, final_faces)

# NOTE(review): the loss figure is created here but never shown or saved;
# add plt.show() or plt.savefig(...) if this runs as a plain script.
plot_losses(losses)

The cow mesh data is available by running:

!mkdir -p data/cow_mesh
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png

Thanks in advance for your help!

Metadata

Metadata

Assignees

Labels

how to: How to use PyTorch3D in my project

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions