tensorf-jax

JAX implementation of Tensorial Radiance Fields, written as an exercise.

@misc{TensoRF,
      title={TensoRF: Tensorial Radiance Fields},
      author={Anpei Chen and Zexiang Xu and Andreas Geiger and Jingyi Yu and Hao Su},
      year={2022},
      eprint={2203.09517},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

We don't attempt to reproduce the original paper exactly, but we can achieve decent results after 5 to 10 minutes of training:

Lego rendering GIF

As proposed, TensoRF only supports scenes that fit in a fixed-size bounding box. We've also added basic support for unbounded "real" scenes via mip-NeRF 360-inspired scene contraction [1]. Example from nerfstudio's "dozer" dataset:

Dozer rendering GIF
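The contraction can be sketched as follows. This is a minimal NumPy sketch (the same calls exist in `jax.numpy`), assuming the mip-NeRF 360 form with the $L^\infty$ norm mentioned in footnote 1: points with $\|x\|_\infty \le 1$ are left unchanged, and all other points map to $(2 - 1/\|x\|_\infty)\, x / \|x\|_\infty$. The function name `contract_linf` is hypothetical and not taken from this repository.

```python
import numpy as np


def contract_linf(x: np.ndarray) -> np.ndarray:
    """Contract points of shape (..., 3) into the open cube (-2, 2)^3.

    Sketch of mip-NeRF 360-style contraction using the L-infinity norm:
    points with ||x||_inf <= 1 are unchanged; all others are mapped to
    (2 - 1 / ||x||_inf) * x / ||x||_inf.
    """
    norm = np.max(np.abs(x), axis=-1, keepdims=True)
    # Clamp the norm to >= 1 so the outer branch never divides by zero
    # (in JAX, this also keeps gradients through jnp.where finite).
    safe_norm = np.maximum(norm, 1.0)
    contracted = (2.0 - 1.0 / safe_norm) * x / safe_norm
    return np.where(norm <= 1.0, x, contracted)


# Usage: one point inside the unit cube, two outside it.
pts = np.array([[0.5, -0.2, 0.1], [4.0, 0.0, 0.0], [2.0, 1.0, 0.0]])
out = contract_linf(pts)  # rows: [0.5, -0.2, 0.1], [1.75, 0, 0], [1.5, 0.75, 0]
```

Because every contracted point lies strictly inside $[-2, 2]^3$, the contracted scene can presumably be handled by the same fixed-size bounding box that TensoRF already assumes.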

Instructions

  1. Download the nerf_synthetic dataset (Google Drive). With the default training script arguments, we expect it to be extracted to ./data, e.g. ./data/nerf_synthetic/lego.

  2. Install dependencies. You probably want the GPU version of JAX; see the official JAX installation instructions. Then:

    pip install -r requirements.txt
  3. To print training options:

    python ./train_lego.py --help
  4. To monitor training, we use TensorBoard:

    tensorboard --logdir=./runs/
  5. To render:

    python ./render_360.py --help

Differences from the PyTorch implementation

This implementation doesn't exactly match the official one:

  • The official implementation relies heavily on masking operations to improve runtime (for example, by using a weight threshold for sampled points). These produce dynamic, data-dependent shapes, which are currently difficult to express in JAX (jit-compiled code requires static array shapes), so we replace them with workarounds like weighted sampling.
  • Several training details that would likely improve performance are not yet implemented: bounding box refinement, ray filtering, regularization, etc.
  • We include mixed-precision training, which can increase training throughput significantly. It is still unclear whether this actually reduces wall-clock training time.
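The dynamic-shape problem above can be illustrated with a fixed-size alternative to thresholding: keep exactly k samples per ray, ranked by weight (the To-do list below mentions a per-ray top-k). This is a minimal NumPy sketch under that assumption, not this repository's code; `select_topk_samples` is a hypothetical name. In JAX, `jax.lax.top_k(weights, k)` returns the k largest values and their indices along the last axis and could replace the argsort.

```python
import numpy as np


def select_topk_samples(weights: np.ndarray, points: np.ndarray, k: int):
    """Keep the k highest-weight samples on each ray, with static shapes.

    weights: (num_rays, num_samples) volume-rendering weights.
    points:  (num_rays, num_samples, 3) sample positions.
    Returns indices (num_rays, k), weights (num_rays, k), points (num_rays, k, 3).
    """
    # Sort each ray's samples by descending weight; a stable sort keeps
    # tie-breaking deterministic.
    order = np.argsort(-weights, axis=-1, kind="stable")
    idx = order[:, :k]
    top_weights = np.take_along_axis(weights, idx, axis=-1)
    # idx[..., None] broadcasts over the xyz axis of `points`.
    top_points = np.take_along_axis(points, idx[..., None], axis=1)
    return idx, top_weights, top_points


# Usage: 2 rays, 4 samples each.
weights = np.array([[0.1, 0.7, 0.05, 0.15], [0.4, 0.1, 0.3, 0.2]])
points = np.arange(24, dtype=float).reshape(2, 4, 3)
idx, top_w, top_p = select_topk_samples(weights, points, k=2)
# idx is [[1, 3], [0, 2]]; the output shapes never depend on the weight values.
# By contrast, points[weights > 0.25] has shape (3, 3) here, a data-dependent
# size that jax.jit cannot trace.
```

The trade-off is that low-weight samples are dropped even on rays where more than k samples matter, and compute is spent on k samples even where fewer would do.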

References

Implementation details are loosely based on the original PyTorch implementation, apchenstu/TensoRF.

unixpickle/learn-nerf and google-research/jaxnerf were also really helpful for understanding core NeRF concepts and connecting them to JAX!

To-do

  • Main implementation
    • Point sampling
    • Feature MLP
    • Rendering
    • VM decomposition
      • Basic implementation
      • Vectorized
    • Dataloading
      • Blender
      • nerfstudio
        • Basics
        • Fisheye support
        • Compute samples without undistorting images (throws away a lot of pixels)
    • Tricks for real data
      • Scene contraction (~mip-NeRF 360)
      • Camera embeddings
  • Training
    • Learning rate scheduler
      • ADAM + grouped LR
      • Exponential decay
      • Reset decay after upsampling
    • Running
    • Checkpointing
    • Logging
      • Loss
      • PSNR
      • Test metrics
      • Test images
      • Render previews
    • Ray filtering
    • Bounding box refinement
    • Incremental upsampling
    • Regularization terms
  • Performance
    • Weight thresholding for computing appearance features
      • per ray top-k
      • global top-k (bad & deleted)
    • Mixed-precision
      • implemented
      • stable
    • Multi-GPU (should be quick)
  • Rendering
    • RGB
    • Depth (median)
    • Depth (mean)
    • Batching
    • Generate some GIFs
  • Misc engineering
    • Actions
    • Understand vmap performance differences (details)
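TensoRF's vector-matrix (VM) decomposition, listed in the To-do above, represents a 3D grid as a sum of rank components, each pairing a 1D vector along one axis with a 2D matrix over the other two axes. The following is a minimal NumPy sketch under simplifying assumptions: a single feature channel, the same rank R for every axis, and lookups at integer grid indices instead of the linear/bilinear interpolation TensoRF uses. The names `vm_dense` and `vm_query` are hypothetical.

```python
import numpy as np


def vm_dense(vec_x, vec_y, vec_z, mat_yz, mat_xz, mat_xy):
    """Rebuild a dense (Nx, Ny, Nz) grid from rank-R VM components.

    Shapes: vec_x (R, Nx), vec_y (R, Ny), vec_z (R, Nz),
            mat_yz (R, Ny, Nz), mat_xz (R, Nx, Nz), mat_xy (R, Nx, Ny).
    Each term pairs a vector along one axis with a matrix over the other two.
    """
    return (
        np.einsum("ri,rjk->ijk", vec_x, mat_yz)
        + np.einsum("rj,rik->ijk", vec_y, mat_xz)
        + np.einsum("rk,rij->ijk", vec_z, mat_xy)
    )


def vm_query(i, j, k, vec_x, vec_y, vec_z, mat_yz, mat_xz, mat_xy):
    """Evaluate the same sum at integer grid indices i, j, k (each shape (P,)).

    Only the queried entries are read, so the dense grid is never built.
    """
    terms = (
        vec_x[:, i] * mat_yz[:, j, k]
        + vec_y[:, j] * mat_xz[:, i, k]
        + vec_z[:, k] * mat_xy[:, i, j]
    )  # (R, P)
    return terms.sum(axis=0)  # sum over rank components -> (P,)


# Usage: random components, checked against the dense reconstruction.
rng = np.random.default_rng(0)
R, Nx, Ny, Nz = 4, 5, 6, 7
components = dict(
    vec_x=rng.normal(size=(R, Nx)),
    vec_y=rng.normal(size=(R, Ny)),
    vec_z=rng.normal(size=(R, Nz)),
    mat_yz=rng.normal(size=(R, Ny, Nz)),
    mat_xz=rng.normal(size=(R, Nx, Nz)),
    mat_xy=rng.normal(size=(R, Nx, Ny)),
)
dense = vm_dense(**components)
i, j, k = np.array([0, 4, 2]), np.array([5, 1, 3]), np.array([6, 0, 2])
values = vm_query(i, j, k, **components)  # equals dense[i, j, k]
```

The appeal is memory: for an N×N×N grid with rank R on every axis, the components hold 3R(N + N²) values instead of N³, e.g. 792,576 versus 2,097,152 for N = 128 and R = 16.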

Footnotes

  1. Same as the original mip-NeRF 360 contraction, but with an $L^\infty$ norm instead of an $L^2$ norm.

Contributors

brentyi