
Comments (3)

YuxuanSnow commented on June 29, 2024

I believe it's caused by the different camera coordinate systems used by the NeRF dataset and PyTorch3D.
You can convert to PyTorch3D coordinates with reference to this figure:
image

Alternatively, you can render a textured mesh (the cow) from two views and test your RGB-D warping function on them.
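As a rough sketch of the first suggestion: assuming the NeRF poses are 4x4 camera-to-world matrices in the Blender/OpenGL convention (+X right, +Y up, camera looking along -Z), and that PyTorch3D uses +X left, +Y up, +Z into the scene with row-vector transforms (`X_cam = X_world @ R + T`), the conversion amounts to flipping the X and Z axes. The function name `nerf_c2w_to_pytorch3d` is hypothetical and not part of either library; check the result against the figure above before relying on it.

```python
import numpy as np


def nerf_c2w_to_pytorch3d(c2w):
    """Convert a 4x4 camera-to-world pose (assumed OpenGL/Blender axes)
    into a PyTorch3D-style (R, T) pair for row-vector transforms."""
    c2w = np.asarray(c2w, dtype=np.float64)
    # World-to-camera in the source (column-vector) convention.
    w2c = np.linalg.inv(c2w)
    R_w2c, t_w2c = w2c[:3, :3], w2c[:3, 3]
    # Flip X and Z so that +X points left and +Z points into the scene.
    flip = np.diag([-1.0, 1.0, -1.0])
    # Transpose because PyTorch3D multiplies points from the left: X @ R + T.
    R = (flip @ R_w2c).T
    T = flip @ t_w2c
    return R, T


# A camera at the world origin with identity pose looks along world -Z.
R, T = nerf_c2w_to_pytorch3d(np.eye(4))
point_world = np.array([0.0, 0.0, -1.0])  # one unit in front of that camera
point_cam = point_world @ R + T           # should have positive Z (depth)
print(point_cam)
```

Note that the translation is converted together with the rotation; if the translation ends up as zero for every frame, the input poses are worth inspecting before looking at the warping code.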


guochengqian commented on June 29, 2024

@YuxuanSnow Hi, thanks for the comment. I use the preprocessed data from PyTorch3D, which should already have been converted to the PyTorch3D coordinate convention. Anyway, I will test what you suggested.


guochengqian commented on June 29, 2024

Hi @YuxuanSnow, thanks for the information. It works now. The issue is with PyTorch3D's preprocessed NeRF_synthetic data: it somehow resets the translation to 0. I therefore switched to rendering from a mesh, applied the same image warping, and it worked. The code is here in case anyone else is interested.

# coding: utf-8

# In[ ]:


# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.


# # Render a textured mesh
# 
# This tutorial shows how to:
# - load a mesh and textures from an `.obj` file. 
# - set up a renderer 
# - render the mesh 
# - vary the rendering settings such as lighting and camera position
# - use the batching features of the pytorch3d API to render the mesh from different viewpoints

# ## 0. Install and Import modules

# Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:

# In[ ]:


import os
import sys
import torch

import matplotlib.pyplot as plt

# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes, load_obj

# Data structures and functions for rendering
from pytorch3d.structures import Meshes
from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras, 
    PointLights, 
    DirectionalLights, 
    Materials, 
    RasterizationSettings, 
    MeshRenderer, 
    MeshRendererWithFragments, 
    MeshRasterizer,  
    SoftPhongShader,
    TexturesUV,
    TexturesVertex
)


# ### 1. Load a mesh and texture file
# 
# Load an `.obj` file and its associated `.mtl` file and create a **Textures** and **Meshes** object. 
# 
# **Meshes** is a unique datastructure provided in PyTorch3D for working with batches of meshes of different sizes. 
# 
# **TexturesUV** is an auxiliary datastructure for storing vertex uv and texture maps for meshes. 
# 
# **Meshes** has several class methods which are used throughout the rendering pipeline.

# If running this notebook using **Google Colab**, run the following cell to fetch the mesh obj and texture files and save it at the path `data/cow_mesh`:
# If running locally, the data is already available at the correct path. 

# In[ ]:


# get_ipython().system('mkdir -p data/cow_mesh')
# get_ipython().system('wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj')
# get_ipython().system('wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl')
# get_ipython().system('wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png')


# In[ ]:


# Setup
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

# Set paths
DATA_DIR = "./data"
obj_filename = os.path.join(DATA_DIR, "cow_mesh/cow.obj")
save_dir = 'debug/cow/'
os.makedirs(save_dir, exist_ok=True)
image_size = [64, 64]

# Load obj file
mesh = load_objs_as_meshes([obj_filename], device=device)


# # #### Let's visualize the texture map

# # In[ ]:

# #%%
# plt.figure(figsize=(7,7))
# texture_image=mesh.textures.maps_padded()
# plt.imshow(texture_image.squeeze().cpu().numpy())
# plt.axis("off");


# # PyTorch3D has a built-in way to view the texture map with matplotlib along with the points on the map corresponding to vertices. There is also a method, texturesuv_image_PIL, to get a similar image which can be saved to a file.

# # In[ ]:


# plt.figure(figsize=(7,7))
# texturesuv_image_matplotlib(mesh.textures, subsample=None)
# plt.axis("off");


# ## 2. Create a renderer
# 
# A renderer in PyTorch3D is composed of a **rasterizer** and a **shader** which each have a number of subcomponents such as a **camera** (orthographic/perspective). Here we initialize some of these components and use default values for the rest.
# 
# In this example we will first create a **renderer** which uses a **perspective camera**, a **point light** and applies **Phong shading**. Then we learn how to vary different components using the modular API.  

# In[ ]:


# Initialize a camera.
# With world coordinates +Y up, +X left and +Z in, the front of the cow is facing the -Z direction. 
# An azimuth of 180 would face the front of the cow directly; azim=150 is used here, so the
# camera views the front slightly from the side.
# TODO: how about znear and zfar here.
R1, T1 = look_at_view_transform(dist=2.7, elev=0, azim=150)
cameras1 = FoVPerspectiveCameras(device=device, R=R1, T=T1)

# Define the settings for rasterization and shading. The output image size is set by
# image_size above (64x64 here). As we are rendering images for visualization purposes only
# we set faces_per_pixel=1 and blur_radius=0.0. bin_size and max_faces_per_bin are left at
# their default of None, which ensures that the faster coarse-to-fine rasterization method
# is used. Refer to rasterize_meshes.py for explanations of these parameters. Refer to
# docs/notes/renderer.md for an explanation of the difference between naive and
# coarse-to-fine rasterization.
raster_settings = RasterizationSettings(
    image_size=image_size, 
    blur_radius=0.0, 
    faces_per_pixel=1, 
)

# Place a point light in front of the object. As mentioned above, the front of the cow is facing the 
# -z direction. 
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

# Create a Phong renderer by composing a rasterizer and a shader. The textured Phong shader will 
# interpolate the texture uv coordinates for each vertex, sample from a texture image and 
# apply the Phong lighting model

# TODO: re-read all CG basics later. 
renderer = MeshRendererWithFragments(
    rasterizer=MeshRasterizer(
        cameras=cameras1, 
        raster_settings=raster_settings
    ),
    shader=SoftPhongShader(
        device=device, 
        cameras=cameras1,
        lights=lights
    )
)


# # Optional: set the specular color (white here) and change material shininess
# materials = Materials(
#     device=device,
#     specular_color=[[1.0, 1.0, 1.0]],
#     shininess=10.0
# )


# ## 3. Render the mesh

# The light is in front of the object so it is bright and the image has specular highlights.

# In[ ]:


image1, fragment1 = renderer(mesh, lights=lights, 
                            #  materials=materials, 
                             cameras=cameras1)
plt.figure(figsize=(10, 10))
plt.imshow(image1[0, ..., :3].cpu().numpy())
plt.axis("off")
plt.savefig(os.path.join(save_dir, 'image1.jpg'))

# The depth comes from the rasterizer's z-buffer, shape (N, H, W, faces_per_pixel).
# Pixels that are not covered by any face are padded with -1; set them to 0 here.
depth1 = fragment1.zbuf
depth1[depth1<0] = 0

plt.figure(figsize=(10, 10))
# faces_per_pixel=1, so take the single depth channel
plt.imshow(depth1[0, ..., 0].cpu().numpy())
plt.axis("off")
plt.savefig(os.path.join(save_dir, 'depth1.jpg'))


# Second view: move the camera by changing the elevation and azimuth angles
R2, T2 = look_at_view_transform(dist=2.7, elev=10, azim=-45)
cameras2 = FoVPerspectiveCameras(device=device, R=R2, T=T2)

# Move the light location so the light is shining on the cow's face.  
# lights.location = torch.tensor([[2.0, 2.0, -2.0]], device=device)


# Re render the mesh, passing in keyword arguments for the modified components.
image2, fragment2 = renderer(mesh, 
                            lights=lights, 
                            #  materials=materials, 
                             cameras=cameras2)


# In[ ]:


plt.figure(figsize=(10, 10))
plt.imshow(image2[0, ..., :3].cpu().numpy())
plt.axis("off");
plt.savefig(os.path.join(save_dir, 'image2.jpg'))


# now warp the view of camera1 into camera2
# TODO: normalized depth or not?
from pytorch3d.renderer import (
    AlphaCompositor,
    NDCMultinomialRaysampler,
    PointsRasterizationSettings,
    PointsRasterizer,
    ray_bundle_to_ray_points,
)
from pytorch3d.structures import Pointclouds


# convert the depth map of view 1 to a point cloud using the grid ray sampler
# image_size is (height, width); both are 64 here
pts_3d = ray_bundle_to_ray_points(
    NDCMultinomialRaysampler(
        image_width=image_size[1],
        image_height=image_size[0],
        n_pts_per_ray=1,
        min_depth=1.0,
        max_depth=1.0,
        unit_directions=False,
    )(cameras1)._replace(lengths=depth1)
)
pts_3d = pts_3d.reshape(-1, 3)
# pts_mask = depth > 0.0
# pts_mask = pts_mask.reshape(-1)

# pts_3d_filtered = pts_3d.reshape(-1, 3)[pts_mask]

# build a colored point cloud from view 1 (points from the depth map, colors from image1)
point_cloud = Pointclouds(points=pts_3d[None], features=image1[..., :3].reshape(1, -1, 3))
# from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
from typing import cast, Optional, Tuple
import torch.nn.functional as Fu

def render_point_cloud_pytorch3d(
    camera,
    point_cloud,
    render_size: Tuple[int, int],
    point_radius: float = 0.03,
    topk: int = 10,
    eps: float = 1e-2,
    bg_color=None,
    bin_size: Optional[int] = None,
    **kwargs,
):

    # feature dimension
    featdim = point_cloud.features_packed().shape[-1]

    # The original version moves the points into camera coordinates with
    # `_transform_points` and then renders with an identity camera (`camera_trivial`).
    # That step is disabled here; the points stay in world coordinates.
    # point_cloud = _transform_points(camera, point_cloud, eps, **kwargs)
    # camera_trivial = camera.clone()
    # camera_trivial.R[:] = torch.eye(3)
    # camera_trivial.T *= 0.0

    bin_size = (
        bin_size
        if bin_size is not None
        else (64 if int(max(render_size)) > 1024 else None)
    )
    rasterizer = PointsRasterizer(
        # The identity camera (`camera_trivial`) is not used because the points
        # are not moved into camera coordinates above; pass the real camera instead.
        # cameras=camera_trivial,
        cameras=camera,
        raster_settings=PointsRasterizationSettings(
            image_size=render_size,
            radius=point_radius,
            points_per_pixel=topk,
            bin_size=bin_size,
        ),
    )

    fragments = rasterizer(point_cloud, **kwargs)

    # Construct weights based on the distance of a point to the true point.
    # However, this could be done differently: e.g. predicted as opposed
    # to a function of the weights.
    r = rasterizer.raster_settings.radius

    # set up the blending weights
    dists2 = fragments.dists
    weights = 1 - dists2 / (r * r)
    ok = cast(torch.BoolTensor, (fragments.idx >= 0)).float()

    weights = weights * ok

    fragments_prm = fragments.idx.long().permute(0, 3, 1, 2)
    weights_prm = weights.permute(0, 3, 1, 2)
    images = AlphaCompositor()(
        fragments_prm,
        weights_prm,
        point_cloud.features_packed().permute(1, 0),
        background_color=bg_color if bg_color is not None else [0.0] * featdim,
        **kwargs,
    )

    # get the depths ...
    # weighted_fs[b,c,i,j] = sum_k cum_alpha_k * features[c,pointsidx[b,k,i,j]]
    # cum_alpha_k = alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j])
    cumprod = torch.cumprod(1 - weights, dim=-1)
    cumprod = torch.cat((torch.ones_like(cumprod[..., :1]), cumprod[..., :-1]), dim=-1)
    depths = (weights * cumprod * fragments.zbuf).sum(dim=-1)
    # add the rendering mask
    # pyre-fixme[6]: For 1st param expected `Tensor` but got `float`.
    render_mask = -torch.prod(1.0 - weights, dim=-1) + 1.0

    # cat depths and render mask
    rendered_blob = torch.cat((images, depths[:, None], render_mask[:, None]), dim=1)

    # reshape back
    rendered_blob = Fu.interpolate(
        rendered_blob,
        size=tuple(render_size),
        mode="bilinear",
        align_corners=False,
    )

    data_rendered, depth_rendered, render_mask = rendered_blob.split(
        [rendered_blob.shape[1] - 2, 1, 1],
        dim=1,
    )

    return data_rendered, render_mask, depth_rendered

data_rendered, render_mask, depth_rendered = render_point_cloud_pytorch3d(cameras1, point_cloud, 
                                                                          render_size=image_size, 
                                                                          point_radius=0.03, 
                                                                          topk=1, eps=1e-2, bg_color=None, bin_size=None)
print(data_rendered.shape, render_mask.shape, depth_rendered.shape)

fig = plt.figure(figsize=(10, 10))
plt.imshow(data_rendered[0].cpu().numpy().transpose(1, 2, 0))
plt.axis("off")
# save before plt.show(), otherwise the saved figure can be empty
plt.savefig(os.path.join(save_dir, 'rendering1.jpg'))
plt.show()


data_rendered2, render_mask2, depth_rendered2 = render_point_cloud_pytorch3d(cameras2, point_cloud, render_size=image_size, 
                                                                          point_radius=0.03, 
                                                                          topk=1, eps=1e-2, bg_color=None, bin_size=None)
print(data_rendered2.shape, render_mask2.shape, depth_rendered2.shape)
fig = plt.figure(figsize=(10, 10))
plt.imshow(data_rendered2[0].cpu().numpy().transpose(1, 2, 0))
plt.axis("off")
plt.savefig(os.path.join(save_dir, 'rendering2.jpg'))

print('Done')
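
For anyone who wants to sanity-check the warping step without PyTorch3D, here is a minimal NumPy sketch of the same idea: unproject each pixel of view 1 with its depth, move the points into the second camera, and project them back, keeping the nearest point per pixel. It assumes a simple pinhole camera with intrinsics `K` and a relative pose defined as `X2 = R_rel @ X1 + t_rel`; the names `unproject` and `warp_rgbd` are made up for this example and are not PyTorch3D functions, and the script above uses point splatting with a radius rather than this one-pixel rounding.

```python
import numpy as np


def unproject(depth, K):
    """Lift an (H, W) z-depth map to (H*W, 3) points in camera coordinates."""
    h, w = depth.shape
    v, u = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    pix = np.stack([u, v, np.ones_like(u)], axis=-1).reshape(-1, 3).astype(np.float64)
    rays = pix @ np.linalg.inv(K).T  # rays with z = 1
    return rays * depth.reshape(-1, 1)


def warp_rgbd(rgb, depth, K, R_rel, t_rel):
    """Forward-warp view 1 into view 2; unfilled pixels stay 0."""
    h, w = depth.shape
    channels = rgb.shape[-1]
    valid = depth.reshape(-1) > 0  # skip background (depth 0)
    pts2 = unproject(depth, K)[valid] @ R_rel.T + t_rel
    colors = rgb.reshape(-1, channels)[valid]
    in_front = pts2[:, 2] > 1e-8  # drop points behind the second camera
    pts2, colors = pts2[in_front], colors[in_front]
    proj = pts2 @ K.T
    u = np.round(proj[:, 0] / proj[:, 2]).astype(int)
    v = np.round(proj[:, 1] / proj[:, 2]).astype(int)
    inside = (u >= 0) & (u < w) & (v >= 0) & (v < h)
    lin = (v * w + u)[inside]  # flat pixel index in view 2
    z, colors = pts2[inside, 2], colors[inside]
    # Sort near-to-far, then keep the first (nearest) point for each pixel.
    order = np.argsort(z, kind="stable")
    target, first = np.unique(lin[order], return_index=True)
    out = np.zeros((h * w, channels), dtype=rgb.dtype)
    out[target] = colors[order][first]
    return out.reshape(h, w, channels)


# Toy check: constant depth 2 and focal length 4, so a 0.5 shift in X
# moves the image content by exactly one pixel to the right.
K = np.array([[4.0, 0.0, 2.0], [0.0, 4.0, 2.0], [0.0, 0.0, 1.0]])
rgb = np.arange(4 * 4 * 3, dtype=np.float64).reshape(4, 4, 3)
depth = np.full((4, 4), 2.0)
same = warp_rgbd(rgb, depth, K, np.eye(3), np.zeros(3))
shifted = warp_rgbd(rgb, depth, K, np.eye(3), np.array([0.5, 0.0, 0.0]))
print(np.array_equal(same, rgb), np.array_equal(shifted[:, 1:], rgb[:, :-1]))
```

With a zero translation between the two poses, as in the preprocessed data mentioned above, the warp reduces to a pure rotation of the rays, which is one way to spot that kind of problem.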
