Comments (10)

Nik-V9 commented on July 29, 2024

Hi @small-zeng, Thanks for your interest in our work!

We would like to mention that the reported speed for the current version of SplaTAM in the paper (Table 6) is 0.5 FPS. We also provide the config for SplaTAM-S (mentioned in the paper), which is the faster variant and has ~2 FPS with similar performance. We see similar numbers as shared by you on an RTX 3080 Ti.

Some other factors that influence overall speed are the read and write speed of the disk and whether wandb is being used.

As mentioned by @Buffyqsf, most of our current overhead comes from PyTorch due to the rigid transformation of Gaussians through large matrix multiplications. In our experiments, we have observed that the speed starts to drop drastically beyond a certain number of Gaussians. Hence, we have seen SplaTAM-S hold its speed throughout the scene (because the Gaussian densification occurs at half resolution).
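
To illustrate where that overhead comes from, here is a minimal PyTorch sketch (illustrative names, not the SplaTAM source) of the per-iteration rigid transformation of all Gaussian centers into the camera frame:

    import torch

    def transform_gaussians(means_w, cam_rot, cam_trans):
        # means_w: (N, 3) Gaussian centers in the world frame
        # cam_rot: (3, 3) world-to-camera rotation; cam_trans: (3,) translation
        # One large matmul per optimization iteration: the cost grows linearly
        # with N, so tracking and mapping slow down as the map densifies.
        return means_w @ cam_rot.T + cam_trans

    N = 2_000_000  # a densified scene can reach millions of Gaussians
    means_w = torch.randn(N, 3, device="cuda")
    cam_rot = torch.eye(3, device="cuda")
    cam_trans = torch.zeros(3, device="cuda")
    means_c = transform_gaussians(means_w, cam_rot, cam_trans)  # runs every iteration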

We believe that a CUDA implementation of the pose gradients would significantly speed up the transformation operation, similar to the rasterizer. We have planned this for V2 (potentially, we expect to see numbers around 20 FPS).

Buffyqsf commented on July 29, 2024

@zero-joke I've tried that, but it's still very slow. I found that it's fast at first but gets slower as time passes, and the average time is not at a level similar to the paper.

tingxixue commented on July 29, 2024

I also encountered this issue. My graphics card is a 4070 Ti, and my FPS is around 0.25.

Boyuan-Tian commented on July 29, 2024

Same here: I got 2.47 s for tracking and 3.95 s for mapping per frame on a 3090.

Buffyqsf commented on July 29, 2024

I got a similar problem on a 3090.

zero-joke commented on July 29, 2024

When setting use_wandb=False in configs/*/*.py, the speed improves to a level similar to the paper.
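
For reference, a sketch of what that change looks like, assuming the config follows the repo's Python-dict style (the exact file and surrounding keys vary per scene):

    # configs/<dataset>/<scene>.py -- illustrative, check your actual config file
    config = dict(
        use_wandb=False,  # disable wandb logging and its per-iteration overhead
        # ... rest of the config unchanged ...
    )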

Buffyqsf commented on July 29, 2024

I found it's related to the transform operation: the code below in get_loss takes most of the time. As the GS model gets larger, transforming the whole model in every iteration is not practical. Can we instead just change the view camera to train the model? (The renderer is very efficient.) @Nik-V9
[screenshots of the transformation code in get_loss]
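
To make the suggestion concrete, a rough PyTorch sketch of the two options (illustrative names, not the SplaTAM or rasterizer API):

    import torch

    def pose_to_w2c(cam_rot, cam_trans):
        # Build a 4x4 world-to-camera matrix from R (3x3) and t (3,)
        w2c = torch.eye(4, device=cam_rot.device)
        w2c[:3, :3] = cam_rot
        w2c[:3, 3] = cam_trans
        return w2c

    # Option A (current bottleneck): transform all N Gaussians in PyTorch
    # every iteration before rendering:
    #   means_c = means_w @ cam_rot.T + cam_trans
    #
    # Option B (suggested): keep the Gaussians in the world frame and hand
    # the pose to the renderer as its view matrix, so the fast CUDA
    # rasterizer applies the transform; this requires the rasterizer to
    # backpropagate to the view matrix, which the comments below work on:
    #   rendered = rasterize(means_w, ..., viewmatrix=pose_to_w2c(cam_rot, cam_trans))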

kuldeepbrd1 commented on July 29, 2024

Thanks for this clarification @Nik-V9, for your comment on camera gradients in graphdeco-inria/gaussian-splatting#84, and for the great work.

I tried adding a simple backward gradient computation for the pose, taking the loss gradients explicitly w.r.t. the camera pose using an over-parameterized SE(3) pose matrix (i.e., its raw matrix elements for now): graphdeco-inria/gaussian-splatting#629

I have not verified it on a metric yet, but I would greatly appreciate feedback on whether this makes sense.

In __global__ void computeCov2DCUDA(...):

       
    // ---------------- Gradients w.r.t. Tcw ----------------
    // 1st gradient portion: from the 2D covariance to the 3D means (t)
    // flattened_3x4_pose = {r00, r01, r02, r10, r11, r12, r20, r21, r22, t0, t1, t2}
    // dL/dTcw = dL/dcov_c * dcov_c/dW * dW/dTcw

    // Loss w.r.t. W: dL/dW = dL/dT * dT/dW = J^T * dL/dT
    float dL_dW00 = J[0][0] * dL_dT00 + J[1][0] * dL_dT10;
    float dL_dW10 = J[0][0] * dL_dT01 + J[1][0] * dL_dT11;
    float dL_dW20 = J[0][0] * dL_dT02 + J[1][0] * dL_dT12;
    float dL_dW01 = J[0][1] * dL_dT00 + J[1][1] * dL_dT10;
    float dL_dW11 = J[0][1] * dL_dT01 + J[1][1] * dL_dT11;
    float dL_dW21 = J[0][1] * dL_dT02 + J[1][1] * dL_dT12;
    float dL_dW02 = J[0][2] * dL_dT00 + J[1][2] * dL_dT10;
    float dL_dW12 = J[0][2] * dL_dT01 + J[1][2] * dL_dT11;
    float dL_dW22 = J[0][2] * dL_dT02 + J[1][2] * dL_dT12;

    // Loss w.r.t. the Tcw elements (rotation block, then translation)
    dL_dTcw[0] += dL_dtx * m.x + dL_dW00;
    dL_dTcw[1] += dL_dtx * m.y + dL_dW10;
    dL_dTcw[2] += dL_dtx * m.z + dL_dW20;
    dL_dTcw[3] += dL_dty * m.x + dL_dW01;
    dL_dTcw[4] += dL_dty * m.y + dL_dW11;
    dL_dTcw[5] += dL_dty * m.z + dL_dW21;
    dL_dTcw[6] += dL_dtz * m.x + dL_dW02;
    dL_dTcw[7] += dL_dtz * m.y + dL_dW12;
    dL_dTcw[8] += dL_dtz * m.z + dL_dW22;
    dL_dTcw[9] += dL_dtx;
    dL_dTcw[10] += dL_dty;
    dL_dTcw[11] += dL_dtz;

And in __global__ void preprocessCUDA():

    // ---------------- Gradients w.r.t. Tcw ----------------
    // 2nd gradient portion: from the 2D means to the 3D means (t)
    // flattened_3x4_pose = {r00, r01, r02, r10, r11, r12, r20, r21, r22, t0, t1, t2}
    dL_dTcw[0] += dL_dmean.x * mean.x;
    dL_dTcw[1] += dL_dmean.x * mean.y;
    dL_dTcw[2] += dL_dmean.x * mean.z;
    dL_dTcw[3] += dL_dmean.y * mean.x;
    dL_dTcw[4] += dL_dmean.y * mean.y;
    dL_dTcw[5] += dL_dmean.y * mean.z;
    dL_dTcw[6] += dL_dmean.z * mean.x;
    dL_dTcw[7] += dL_dmean.z * mean.y;
    dL_dTcw[8] += dL_dmean.z * mean.z;
    dL_dTcw[9] += dL_dmean.x;
    dL_dTcw[10] += dL_dmean.y;
    dL_dTcw[11] += dL_dmean.z;

kuldeepbrd1 commented on July 29, 2024

I was able to get gradients w.r.t. position and quaternion, and they seem to work. I just needed to take the $3 \times 3$ rotation block of $dL/dT_{cw}$ and compute the Frobenius inner product $\left\langle \frac{\partial L}{\partial R_{cw}}, \frac{\partial R_{cw}}{\partial q_{cw}^i} \right\rangle$ for each scalar quaternion element $q_{cw}^i$.
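
Written out by expanding the Frobenius inner product (my notation), each quaternion gradient is

$$\frac{\partial L}{\partial q_{cw}^i} = \sum_{j,k} \left(\frac{\partial L}{\partial R_{cw}}\right)_{jk} \left(\frac{\partial R_{cw}}{\partial q_{cw}^i}\right)_{jk},$$

computed on top of the dL_dTcw from the earlier comment, as follows: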

    // ---------------- Gradients w.r.t. t_cw and q_cw ----------------
    // q_cw is the normalized quaternion in wxyz format
    // t_cw is the translation vector
    // Take the gradient elements corresponding to the rotation
    glm::mat3 dL_dRcw = glm::mat3(dL_dTcw[0], dL_dTcw[3], dL_dTcw[6],
                                  dL_dTcw[1], dL_dTcw[4], dL_dTcw[7],
                                  dL_dTcw[2], dL_dTcw[5], dL_dTcw[8]);
    glm::mat3 dL_dRcwT = glm::transpose(dL_dRcw);

    float w = quat[0];
    float x = quat[1];
    float y = quat[2];
    float z = quat[3];

    // dL/dq = <dL/dRcw, dRcw/dq>, expanded per quaternion element
    glm::vec4 dL_dqcw;
    dL_dqcw.w = 2 * z * (dL_dRcwT[0][1] - dL_dRcwT[1][0]) + 2 * y * (dL_dRcwT[2][0] - dL_dRcwT[0][2]) + 2 * x * (dL_dRcwT[1][2] - dL_dRcwT[2][1]);
    dL_dqcw.x = 2 * y * (dL_dRcwT[1][0] + dL_dRcwT[0][1]) + 2 * z * (dL_dRcwT[2][0] + dL_dRcwT[0][2]) + 2 * w * (dL_dRcwT[1][2] - dL_dRcwT[2][1]) - 4 * x * (dL_dRcwT[2][2] + dL_dRcwT[1][1]);
    dL_dqcw.y = 2 * x * (dL_dRcwT[1][0] + dL_dRcwT[0][1]) + 2 * w * (dL_dRcwT[2][0] - dL_dRcwT[0][2]) + 2 * z * (dL_dRcwT[1][2] + dL_dRcwT[2][1]) - 4 * y * (dL_dRcwT[2][2] + dL_dRcwT[0][0]);
    dL_dqcw.z = 2 * w * (dL_dRcwT[0][1] - dL_dRcwT[1][0]) + 2 * x * (dL_dRcwT[2][0] + dL_dRcwT[0][2]) + 2 * y * (dL_dRcwT[1][2] + dL_dRcwT[2][1]) - 4 * z * (dL_dRcwT[1][1] + dL_dRcwT[0][0]);

    // Gradients of the loss w.r.t. translation
    dL_dtq[0] += -dL_dTcw[9];
    dL_dtq[1] += -dL_dTcw[10];
    dL_dtq[2] += -dL_dTcw[11];

    // Gradients of the loss w.r.t. quaternion
    dL_dtq[3] += dL_dqcw.w;
    dL_dtq[4] += dL_dqcw.x;
    dL_dtq[5] += dL_dqcw.y;
    dL_dtq[6] += dL_dqcw.z;

I have implemented this in the fork of the rasterization submodule here.

I would be grateful if someone could check or verify this.
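
One way to sanity-check analytic pose gradients like these, without ground truth, is a central finite-difference comparison on the rendering loss. A sketch under assumptions: render_loss is a hypothetical wrapper that renders with the given pose parameters and returns a scalar loss tensor (it is not part of the fork):

    import torch

    def finite_diff_grad(render_loss, pose_params, eps=1e-4):
        # Central differences: dL/dp_i ~ (L(p + eps*e_i) - L(p - eps*e_i)) / (2*eps)
        base = pose_params.detach()
        grad = torch.zeros_like(base)
        for i in range(base.numel()):
            p_plus, p_minus = base.clone(), base.clone()
            p_plus.view(-1)[i] += eps
            p_minus.view(-1)[i] -= eps
            grad.view(-1)[i] = (render_loss(p_plus) - render_loss(p_minus)) / (2 * eps)
        return grad

    # Compare against the analytic CUDA gradients, e.g.:
    #   loss = render_loss(pose_params); loss.backward()
    #   assert torch.allclose(pose_params.grad,
    #                         finite_diff_grad(render_loss, pose_params), atol=1e-2)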

kevintsq commented on July 29, 2024

I think it's slow partly because some tensor operations are done on the CPU rather than the GPU, so all physical CPU cores are consumed throughout the entire training process. It's good practice to create tensors directly on the GPU and specify their dtype, which prevents expensive computation on the CPU and reduces copy/casting overhead. After this modification, the training process achieves 100% GPU utilization while consuming only one CPU core, and runs about 1x faster than the version that saturated 52 CPU cores, which is what you want, since the expensive operations are done on the GPU.
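
For example, a generic PyTorch pattern (not a specific SplaTAM diff):

    import torch

    data = [[1.0, 2.0], [3.0, 4.0]]  # e.g., values computed in Python

    # Slow: the tensor is built on the CPU, then cast and copied to the GPU
    x_slow = torch.tensor(data).float().cuda()

    # Better: create it on the GPU with the right dtype in one step,
    # skipping CPU-side construction and the extra copy/cast
    x_fast = torch.tensor(data, dtype=torch.float32, device="cuda")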
