
Comments (4)

mmacklin commented on August 28, 2024

Hi @johannespitz, in general you can use the adjoint=True flag to manually invoke the backward version of a kernel. This is what the wp.Tape object does, but it also takes care of a few other complications like tracking launches, zeroing gradients, etc.
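
For concreteness, here is a rough sketch of the manual route with a toy kernel (the kernel and array names are made up, and the way the adjoint arrays are passed via adj_inputs reflects my reading of wp.launch, so treat it as an approximation rather than the definitive API):

import warp as wp

@wp.kernel
def square(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    y[i] = x[i] * x[i]

x = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
y = wp.zeros(3, dtype=float, requires_grad=True)

# forward pass
wp.launch(square, dim=3, inputs=[x, y])

# backward pass: seed the output adjoint, then launch the adjoint
# version of the kernel by passing adjoint=True plus the adjoint arrays
y.grad.fill_(1.0)
wp.launch(square, dim=3, inputs=[x, y], adj_inputs=[x.grad, y.grad], adjoint=True)

print(x.grad)  # d(sum(y))/dx = 2*x, accumulated into x.grad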

When you run a backward pass, gradients are accumulated (the adjoint kernels always add to the existing gradient arrays). This is similar to PyTorch, but it does mean you need to make sure the gradients are zeroed somewhere between optimization steps.
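
A minimal sketch of what that looks like in an optimization loop (the kernels and learning rate here are invented for illustration); the important part is the tape.zero() between steps:

import warp as wp

@wp.kernel
def loss_kernel(x: wp.array(dtype=float), loss: wp.array(dtype=float)):
    i = wp.tid()
    wp.atomic_add(loss, 0, x[i] * x[i])

@wp.kernel
def sgd_step(x: wp.array(dtype=float), grad: wp.array(dtype=float), lr: float):
    i = wp.tid()
    x[i] = x[i] - lr * grad[i]

x = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
loss = wp.zeros(1, dtype=float, requires_grad=True)

for step in range(100):
    loss.zero_()

    tape = wp.Tape()
    with tape:
        wp.launch(loss_kernel, dim=3, inputs=[x, loss])

    tape.backward(loss=loss)                          # adds into x.grad
    wp.launch(sgd_step, dim=3, inputs=[x, x.grad, 0.1])

    # gradients accumulate across backward passes, so clear them here,
    # otherwise the next step adds new gradients on top of stale ones
    tape.zero()                                       # or: x.grad.zero_()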

I don't think you should need to call wp.synchronize() explicitly. @nvlukasz can you confirm?


johannespitz commented on August 28, 2024

Thanks for the reply @mmacklin!
And I would be very interested to hear if/where we really need to call wp.synchronize(), @nvlukasz.

Regarding the accumulation of gradients: when we use wp.from_torch() directly, as it is used in

ctx.joint_q = wp.from_torch(joint_q)

instead of creating a new PyTorch tensor with .clone(), the gradient of leaf nodes in the computation graph will be 2x the true gradient, even when we clear all gradients before the call.
That is because torch expects a torch.autograd.Function to return the gradient rather than write it directly into the buffer. Torch therefore adds the returned gradient to the gradient that Warp has already written into the buffer (for leaves in the computation graph). Note that for intermediate nodes this only works because, usually (if retain_graph=False), the gradient buffers of those tensors are not used at all.
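
To make the point concrete, here is a rough sketch of the pattern in question with a toy kernel (names are mine, not from the example): cloning before wp.from_torch() keeps Warp's gradient buffer separate from the torch leaf's .grad, so autograd only accumulates the value returned from backward():

import torch
import warp as wp

@wp.kernel
def square(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    y[i] = x[i] * x[i]

class SquareFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # without the .clone(), the Warp array can end up sharing the leaf's
        # gradient storage, and (as described above) the gradient is counted twice
        ctx.x = wp.from_torch(x.detach().clone(), requires_grad=True)
        ctx.y = wp.zeros(x.shape[0], dtype=float, requires_grad=True, device=ctx.x.device)
        ctx.tape = wp.Tape()
        with ctx.tape:
            wp.launch(square, dim=x.shape[0], inputs=[ctx.x, ctx.y], device=ctx.x.device)
        return wp.to_torch(ctx.y)

    @staticmethod
    def backward(ctx, grad_y):
        ctx.tape.backward(grads={ctx.y: wp.from_torch(grad_y.contiguous())})
        # return the gradient; autograd adds it into x.grad exactly once
        return wp.to_torch(ctx.x.grad)

x = torch.tensor([1.0, 2.0, 3.0], device="cuda", requires_grad=True)
SquareFn.apply(x).sum().backward()
print(x.grad)  # expect 2*x, not 4*x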


nvlukasz commented on August 28, 2024

CUDA synchronization can be a little tricky, especially when launching work using multiple frameworks that use different scheduling mechanisms under the hood.

Short answer: If you're not explicitly creating and using custom CUDA streams in PyTorch or Warp, and both are targeting the same device, then synchronization is not necessary.

Long answer: By default, PyTorch uses the legacy default stream on each device. This stream is synchronous with respect to other blocking streams on the same device, so no explicit synchronization is needed. Warp, by default, uses a blocking stream on each device, so Warp operations will automatically synchronize with PyTorch operations on the same device.

The picture changes if you start using custom streams in PyTorch. Those streams will not automatically synchronize with Warp streams, so manual synchronization is required. This can be done using wp.synchronize(), wp.synchronize_device(), or wp.synchronize_stream(). These functions synchronize the host with outstanding GPU work, so new work is only launched after prior work completes. We also support event-based device-side synchronization, which is generally faster because it doesn't sync the host and only ensures that the operations are ordered correctly on the device. This includes wp.wait_stream() and wp.wait_event(), as well as interop utilities like wp.stream_from_torch() and wp.stream_to_torch().
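
For example, a sketch of the custom-stream case (the tensor sizes and the PyTorch work are placeholders), using the event-based route mentioned above:

import torch
import warp as wp

torch_stream = torch.cuda.Stream()

# PyTorch work queued on a custom (non-default) stream
with torch.cuda.stream(torch_stream):
    a = torch.rand(1 << 20, device="cuda")
    a.mul_(2.0)

# Option 1: host-side sync (simple, but blocks the CPU)
# wp.synchronize()

# Option 2: device-side sync; make Warp's current stream wait on the
# PyTorch stream without stalling the host
wp.wait_stream(wp.stream_from_torch(torch_stream))

x = wp.from_torch(a)  # now safe to consume from Warp kernel launches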

Note that when capturing CUDA graphs using PyTorch, a non-default stream is used, so synchronization becomes important.

Things can get a little complicated with multi-stream usage and graph capture, so we're working on extended documentation in this area! But in your simple example, the explicit synchronization shouldn't be necessary.


johannespitz commented on August 28, 2024

Thank you for the detailed answer regarding the synchronization! @nvlukasz

Though, could either of you comment on the accumulation of the gradients again, @mmacklin?
Am I missing something, or is the example code incorrect at the moment?
(Note: optimizations with 2x the gradient will likely work just fine, but if someone wants to extend the code they might run into problems.)

