Comments (10)
Hi @small-zeng, Thanks for your interest in our work!
We would like to mention that the reported speed for the current version of SplaTAM in the paper (Table 6) is 0.5 FPS. We also provide the config for SplaTAM-S (mentioned in the paper), which is the faster variant and has ~2 FPS with similar performance. We see similar numbers as shared by you on an RTX 3080 Ti.
Some other factors that influence overall speed are the read and write speed of the disk and whether wandb is being used.
As mentioned by @Buffyqsf, most of our current overhead comes from PyTorch, due to the rigid transformation of the Gaussians through large matrix multiplications. In our experiments, we have observed that the speed starts to decrease drastically beyond a certain number of Gaussians. In this context, we have seen SplaTAM-S hold its speed throughout the scene (because Gaussian densification occurs at half resolution).
We believe that a CUDA implementation of the pose gradients, similar to the rasterizer, would significantly speed up the transformation operation. We have planned this for V2 (potentially, we expect to see numbers around 20 FPS).
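The rigid-transformation bottleneck described above is essentially one large matrix multiplication over all Gaussian centers per optimizer iteration, so the per-iteration cost grows with densification. A minimal numpy sketch of the operation (hypothetical shapes and names, not SplaTAM's actual code):

```python
import numpy as np

def transform_gaussians(means_w, R_cw, t_cw):
    """Rigidly transform all Gaussian centers from world to camera frame.

    means_w: (N, 3) world-space centers; R_cw: (3, 3) rotation; t_cw: (3,) translation.
    Cost is linear in N, so the per-iteration time rises as densification
    adds Gaussians, matching the slowdown reported in this thread.
    """
    return means_w @ R_cw.T + t_cw

rng = np.random.default_rng(0)
means = rng.standard_normal((100_000, 3))   # toy map with 100k Gaussians
R = np.eye(3)                               # identity rotation for illustration
t = np.array([0.0, 0.0, 1.0])
cam_means = transform_gaussians(means, R, t)
```

Because this runs in PyTorch on the full set of Gaussians every tracking/mapping iteration, it dominates once the map grows, which is why a CUDA pose-gradient path is attractive.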
from splatam.
@zero-joke I've tried that, but it's still very slow. I found that it's fast at first but gets slower as time passes, and the average time is not at a level similar to the paper's.
I also encountered this issue. My graphics card is a 4070 Ti, and my FPS is around 0.25.
Same here, I got 2.47s for tracking and 3.95s for mapping at each frame on 3090.
I got a similar problem on a 3090.
When setting use_wandb=False in 'configs//.py', the speed improves to a level similar to the paper.
I found it's related to the transform operation. The code below in `get_loss` takes most of the time. As the GS model gets larger, transforming the whole model in every iteration is not rational. Can we instead just change the view camera to train the model? (The renderer is very efficient.) @Nik-V9
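For intuition on this suggestion: the camera-space coordinates obtained by transforming every Gaussian in PyTorch are identical to those obtained by folding the pose into the 4x4 view matrix that the rasterizer already applies in CUDA. A toy numpy check (hypothetical shapes, not SplaTAM code):

```python
import numpy as np

rng = np.random.default_rng(1)
pts_w = rng.standard_normal((1000, 3))            # world-space Gaussian centers
R_cw = np.array([[0.0, -1.0, 0.0],
                 [1.0,  0.0, 0.0],
                 [0.0,  0.0, 1.0]])               # 90-degree rotation about z
t_cw = np.array([0.5, -0.2, 2.0])

# Slow path (what the PyTorch code does): transform every Gaussian each iteration.
pts_c_slow = pts_w @ R_cw.T + t_cw

# Proposed path: fold the pose into a 4x4 view matrix and let the renderer
# (which applies the view matrix in fast CUDA code) do the transformation.
view = np.eye(4)
view[:3, :3] = R_cw
view[:3, 3] = t_cw
pts_h = np.hstack([pts_w, np.ones((len(pts_w), 1))])  # homogeneous coordinates
pts_c_fast = (pts_h @ view.T)[:, :3]
```

The catch for tracking is that the stock rasterizer does not return gradients with respect to the view matrix, which is exactly what the pose-gradient work discussed below addresses.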
Thanks for this clarification @Nik-V9 and your work.
I added a simple backward gradient computation for the pose (at the moment using overparameterized pose matrix elements):
graphdeco-inria/gaussian-splatting#629
Does this make sense? I have not verified it on a metric yet, but I would love your feedback.
Thanks for this great work and your comment on camera gradients in graphdeco-inria/gaussian-splatting#84.
I tried adding backward gradients of the loss explicitly w.r.t the camera pose using an over-parameterized SE(3) pose matrix.
Would greatly appreciate it if you could check whether this makes sense.

In `__global__ void computeCov2DCUDA(...)`:

// ---------------- Gradients w.r.t Tcw ----------------
// Gradients due to 2D covariance
// (1st gradient portion, from 2D covariance -> 3D means (t))
// flattened_3x4_pose = {r00, r01, r02, r10, r11, r12, r20, r21, r22, t0, t1, t2}
// dL/dTcw = dL/dcov_c * dcov_c/dW * dW/dTcw
// Loss w.r.t W: dL/dW = dL/dT * dT/dW
// dL/dW = J^T * dL/dT
float dL_dW00 = J[0][0] * dL_dT00 + J[1][0] * dL_dT10;
float dL_dW10 = J[0][0] * dL_dT01 + J[1][0] * dL_dT11;
float dL_dW20 = J[0][0] * dL_dT02 + J[1][0] * dL_dT12;
float dL_dW01 = J[0][1] * dL_dT00 + J[1][1] * dL_dT10;
float dL_dW11 = J[0][1] * dL_dT01 + J[1][1] * dL_dT11;
float dL_dW21 = J[0][1] * dL_dT02 + J[1][1] * dL_dT12;
float dL_dW02 = J[0][2] * dL_dT00 + J[1][2] * dL_dT10;
float dL_dW12 = J[0][2] * dL_dT01 + J[1][2] * dL_dT11;
float dL_dW22 = J[0][2] * dL_dT02 + J[1][2] * dL_dT12;
// Loss w.r.t Tcw elements
dL_dTcw[0] += dL_dtx * m.x + dL_dW00;
dL_dTcw[1] += dL_dtx * m.y + dL_dW10;
dL_dTcw[2] += dL_dtx * m.z + dL_dW20;
dL_dTcw[3] += dL_dty * m.x + dL_dW01;
dL_dTcw[4] += dL_dty * m.y + dL_dW11;
dL_dTcw[5] += dL_dty * m.z + dL_dW21;
dL_dTcw[6] += dL_dtz * m.x + dL_dW02;
dL_dTcw[7] += dL_dtz * m.y + dL_dW12;
dL_dTcw[8] += dL_dtz * m.z + dL_dW22;
dL_dTcw[9] += dL_dtx;
dL_dTcw[10] += dL_dty;
dL_dTcw[11] += dL_dtz;

And in `__global__ void preprocessCUDA()`:

// ---------------- Gradients w.r.t Tcw ----------------
// Gradients for 3x4 elements due to 2D means
// (2nd gradient portion, from 2D means -> 3D means (t))
// flattened_3x4_pose = {r00, r01, r02, r10, r11, r12, r20, r21, r22, t0, t1, t2}
dL_dTcw[0] += dL_dmean.x * mean.x;
dL_dTcw[1] += dL_dmean.x * mean.y;
dL_dTcw[2] += dL_dmean.x * mean.z;
dL_dTcw[3] += dL_dmean.y * mean.x;
dL_dTcw[4] += dL_dmean.y * mean.y;
dL_dTcw[5] += dL_dmean.y * mean.z;
dL_dTcw[6] += dL_dmean.z * mean.x;
dL_dTcw[7] += dL_dmean.z * mean.y;
dL_dTcw[8] += dL_dmean.z * mean.z;
dL_dTcw[9] += dL_dmean.x;
dL_dTcw[10] += dL_dmean.y;
dL_dTcw[11] += dL_dmean.z;
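Accumulations like these can be sanity-checked numerically. A minimal numpy sketch with a toy quadratic loss (my own example, not the rasterizer's loss) compares the analytic dL/dTcw pattern used above (dL/dmean times the 3D mean for the rotation part, dL/dmean summed for the translation part) against central finite differences:

```python
import numpy as np

def loss(T_cw, pts_w):
    """Toy scalar loss: sum of squared camera-space coordinates.

    T_cw is a flattened 3x4 pose [r00..r22, t0, t1, t2], matching the
    layout in the kernel comments above.
    """
    R = T_cw[:9].reshape(3, 3)
    t = T_cw[9:]
    pts_c = pts_w @ R.T + t
    return np.sum(pts_c ** 2)

def analytic_grad(T_cw, pts_w):
    R = T_cw[:9].reshape(3, 3)
    t = T_cw[9:]
    pts_c = pts_w @ R.T + t          # (N, 3) camera-space points
    dL_dpc = 2.0 * pts_c             # dL/d(camera-space point)
    dL_dR = dL_dpc.T @ pts_w         # accumulate dL/dmean * mean, as in the kernel
    dL_dt = dL_dpc.sum(axis=0)       # accumulate dL/dmean for the translation
    return np.concatenate([dL_dR.ravel(), dL_dt])

rng = np.random.default_rng(2)
pts = rng.standard_normal((50, 3))
T = np.concatenate([np.eye(3).ravel(), np.array([0.1, -0.3, 2.0])])

g_analytic = analytic_grad(T, pts)
g_numeric = np.zeros(12)
eps = 1e-6
for i in range(12):
    Tp, Tm = T.copy(), T.copy()
    Tp[i] += eps
    Tm[i] -= eps
    g_numeric[i] = (loss(Tp, pts) - loss(Tm, pts)) / (2 * eps)
```

The same finite-difference comparison can be run against the CUDA kernel's output to verify the full pipeline, since the kernel should agree with numerics regardless of how the loss is defined.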
I was able to get gradients w.r.t. position and quaternion, and they seem to work. I just needed to use the 3x3 gradients from dL_dTcw from the earlier comment, as follows:
// ---------------- Gradients w.r.t t_cw and q_cw ------------
// q_cw is the normalized quaternion in wxyz format
// t_cw is the translation vector
// Take the gradient elements corresponding to the rotation
glm::mat3 dL_dRcw = glm::mat3(dL_dTcw[0], dL_dTcw[3], dL_dTcw[6],
dL_dTcw[1], dL_dTcw[4], dL_dTcw[7],
dL_dTcw[2], dL_dTcw[5], dL_dTcw[8]);
glm::mat3 dL_dRcwT = glm::transpose(dL_dRcw);
float w = quat[0];
float x = quat[1];
float y = quat[2];
float z = quat[3];
// dL/dq = dL/dRcw * dRcw/dq
glm::vec4 dL_dqcw;
dL_dqcw.w = 2 * z * (dL_dRcwT[0][1] - dL_dRcwT[1][0]) + 2 * y * (dL_dRcwT[2][0] - dL_dRcwT[0][2]) + 2 * x * (dL_dRcwT[1][2] - dL_dRcwT[2][1]);
dL_dqcw.x = 2 * y * (dL_dRcwT[1][0] + dL_dRcwT[0][1]) + 2 * z * (dL_dRcwT[2][0] + dL_dRcwT[0][2]) + 2 * w * (dL_dRcwT[1][2] - dL_dRcwT[2][1]) - 4 * x * (dL_dRcwT[2][2] + dL_dRcwT[1][1]);
dL_dqcw.y = 2 * x * (dL_dRcwT[1][0] + dL_dRcwT[0][1]) + 2 * w * (dL_dRcwT[2][0] - dL_dRcwT[0][2]) + 2 * z * (dL_dRcwT[1][2] + dL_dRcwT[2][1]) - 4 * y * (dL_dRcwT[2][2] + dL_dRcwT[0][0]);
dL_dqcw.z = 2 * w * (dL_dRcwT[0][1] - dL_dRcwT[1][0]) + 2 * x * (dL_dRcwT[2][0] + dL_dRcwT[0][2]) + 2 * y * (dL_dRcwT[1][2] + dL_dRcwT[2][1]) - 4 * z * (dL_dRcwT[1][1] + dL_dRcwT[0][0]);
// Gradients of loss w.r.t. translation
dL_dtq[0] += -dL_dTcw[9];
dL_dtq[1] += -dL_dTcw[10];
dL_dtq[2] += -dL_dTcw[11];
// Gradients of loss w.r.t. quaternion
dL_dtq[3] += dL_dqcw.w;
dL_dtq[4] += dL_dqcw.x;
dL_dtq[5] += dL_dqcw.y;
dL_dtq[6] += dL_dqcw.z;
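As a sanity check on the dL_dqcw expressions, the same chain rule can be re-derived with M = dL/dR in plain row-major indexing and compared against central finite differences of the quaternion-to-rotation map. A small numpy sketch (my own toy check, not the submodule code; the formulas below agree with the glm expressions above up to the transpose/column-major convention):

```python
import numpy as np

def quat_to_R(q):
    """Rotation matrix from a quaternion in wxyz order (as in the snippet above)."""
    w, x, y, z = q
    return np.array([
        [1 - 2*(y*y + z*z), 2*(x*y - w*z),     2*(x*z + w*y)],
        [2*(x*y + w*z),     1 - 2*(x*x + z*z), 2*(y*z - w*x)],
        [2*(x*z - w*y),     2*(y*z + w*x),     1 - 2*(x*x + y*y)],
    ])

def dL_dq_analytic(q, M):
    """dL/dq given M = dL/dR (row-major), mirroring the dL_dqcw formulas."""
    w, x, y, z = q
    return np.array([
        2*x*(M[2,1] - M[1,2]) + 2*y*(M[0,2] - M[2,0]) + 2*z*(M[1,0] - M[0,1]),
        2*y*(M[0,1] + M[1,0]) + 2*z*(M[0,2] + M[2,0]) + 2*w*(M[2,1] - M[1,2]) - 4*x*(M[1,1] + M[2,2]),
        2*x*(M[0,1] + M[1,0]) + 2*w*(M[0,2] - M[2,0]) + 2*z*(M[1,2] + M[2,1]) - 4*y*(M[0,0] + M[2,2]),
        2*w*(M[1,0] - M[0,1]) + 2*x*(M[0,2] + M[2,0]) + 2*y*(M[1,2] + M[2,1]) - 4*z*(M[0,0] + M[1,1]),
    ])

rng = np.random.default_rng(3)
q = rng.standard_normal(4)
q /= np.linalg.norm(q)                  # unit quaternion
M = rng.standard_normal((3, 3))         # arbitrary upstream gradient dL/dR

# Central finite differences of L = sum(M * R(q)) w.r.t. each quaternion element.
eps = 1e-6
g_num = np.zeros(4)
for k in range(4):
    qp, qm = q.copy(), q.copy()
    qp[k] += eps
    qm[k] -= eps
    g_num[k] = np.sum(M * (quat_to_R(qp) - quat_to_R(qm))) / (2 * eps)
```

Note these are gradients with respect to the four raw quaternion elements; if the optimizer renormalizes the quaternion, the normalization Jacobian would need to be chained in as well.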
I have implemented this in a fork of the rasterization submodule here.
I would be grateful if someone could check or verify it.
I think it's slow partly because some tensor operations are done on the CPU rather than the GPU, so all physical CPU cores are consumed throughout the entire training process. It's good practice to create tensors directly on the GPU and specify their type, preventing expensive computation on the CPU and reducing copy/cast overhead. After this modification, the training process achieves 100% GPU utilization while consuming only one CPU core, and is about 1x faster than the run that consumed 52 CPU cores, which is desirable since the expensive operations are done on the GPU.
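The pattern described here can be sketched in a few lines; a hedged PyTorch example (not SplaTAM's actual code, which falls back to CPU when no GPU is present):

```python
import torch

# Pick the target device once; fall back to CPU when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Slower pattern: the tensor is built on the CPU with a default dtype,
# then cast and copied to the target device -- two extra passes over the data.
pts_slow = torch.tensor([0, 1, 2]).float().to(device)

# Faster pattern: allocate directly on the target device with an explicit
# dtype, avoiding the CPU-side construction, cast, and host-to-device copy.
pts_fast = torch.zeros(3, dtype=torch.float32, device=device)
```

Applied throughout a training loop, this keeps the hot path on the GPU and removes repeated host-to-device transfers.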