
Comments (11)

github-actions avatar github-actions commented on July 26, 2024

Hi! Thanks for your contribution, great first issue!


calebclayreagor avatar calebclayreagor commented on July 26, 2024

I'm running into the same issue with PrecisionRecallCurve when calling .compute() during or after any step while using DDP.
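For context, a minimal sketch of the pattern being described (illustrative only, not the reporter's code; the exact PrecisionRecallCurve constructor arguments depend on the torchmetrics version):

```python
import torch
import pytorch_lightning as pl
from torchmetrics import PrecisionRecallCurve

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 3)
        # Older torchmetrics releases take num_classes directly.
        self.pr_curve = PrecisionRecallCurve(num_classes=3)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = torch.softmax(self.layer(x), dim=-1)
        self.pr_curve.update(preds, y)
        # Calling compute() here (or at epoch end) triggers the cross-rank
        # gather that can hang or OOM under DDP.
        precision, recall, thresholds = self.pr_curve.compute()
```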


maximsch2 avatar maximsch2 commented on July 26, 2024

@calebclayreagor, how big is your data? PrecisionRecallCurve stores the entire dataset in memory, and on compute() it needs to consolidate it on a single rank. If both your model and dataset are big, then at some point the model plus all predictions/labels will not fit into GPU memory, which leads to NCCL issues, GPU OOMs, or hangs.

I have a solution for PrecisionRecall-based metrics in #128 that uses binning to make the compute constant-memory (as opposed to O(dataset size) right now). You do trade off a bit of accuracy for it and have to specify the number of thresholds to use.
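A rough sketch of the binning idea (a hypothetical helper class, not the actual #128 implementation): keep fixed-size TP/FP/FN counters per threshold instead of storing every prediction, so memory stays O(num_thresholds) rather than O(dataset size).

```python
import torch

class BinnedPRSketch:
    def __init__(self, num_thresholds: int = 100):
        self.thresholds = torch.linspace(0, 1, num_thresholds)
        self.tp = torch.zeros(num_thresholds)
        self.fp = torch.zeros(num_thresholds)
        self.fn = torch.zeros(num_thresholds)

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # preds: probabilities in [0, 1]; target: binary {0, 1} labels
        pred_above = preds.unsqueeze(0) >= self.thresholds.unsqueeze(1)
        target_pos = target.bool().unsqueeze(0)
        self.tp += (pred_above & target_pos).sum(dim=1)
        self.fp += (pred_above & ~target_pos).sum(dim=1)
        self.fn += (~pred_above & target_pos).sum(dim=1)

    def compute(self):
        precision = self.tp / (self.tp + self.fp + 1e-12)
        recall = self.tp / (self.tp + self.fn + 1e-12)
        return precision, recall, self.thresholds
```

Because only the counters (not the raw predictions) need to be synced across ranks, the DDP reduction is also constant-size.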


Borda avatar Borda commented on July 26, 2024

@SkafteNicki have you checked this issue? 🐰


SkafteNicki avatar SkafteNicki commented on July 26, 2024

@Borda I cannot debug this myself at the moment, as my local cluster is under maintenance.


calebclayreagor avatar calebclayreagor commented on July 26, 2024

@maximsch2 my dataset is quite large (>6M examples) but it still fits in memory. My problem was actually due to ghost processes, and I solved it by running kill -9 <pid> on them before starting training in DDP mode. Newbie error.


Borda avatar Borda commented on July 26, 2024

@justusschock mind having a look?


maximsch2 avatar maximsch2 commented on July 26, 2024

@vilon888, are you still seeing this issue? Can you check one thing: does your dataloader always produce batches of the same size on all workers?

@SkafteNicki, I've finally debugged a similar issue we've been having, and it's due to the handling of datasets that don't divide evenly into the full number of batches. The last batch ends up partial, with different lengths on different workers, so the preds/target tensors have different shapes and that breaks gather_all_tensors. We probably need to be smarter there: first gather the max shape across all workers, then pad each local tensor to that max shape, gather, and finally truncate back.
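A sketch of the padded all-gather described above (assuming torch.distributed is already initialized; the function and variable names are illustrative, not the actual fix in #220):

```python
import torch
import torch.distributed as dist

def gather_uneven_tensors(tensor: torch.Tensor) -> list:
    world_size = dist.get_world_size()

    # 1) Gather each rank's length along dim 0 and find the maximum.
    local_size = torch.tensor([tensor.shape[0]], device=tensor.device)
    sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(sizes, local_size)
    max_size = int(torch.stack(sizes).max())

    # 2) Pad the local tensor up to the max size so shapes match on all ranks.
    pad_amount = max_size - tensor.shape[0]
    if pad_amount > 0:
        padding = tensor.new_zeros((pad_amount, *tensor.shape[1:]))
        tensor = torch.cat([tensor, padding], dim=0)

    # 3) All-gather the now equal-shaped tensors, then truncate each one back
    #    to its original length.
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor)
    return [g[: int(s)] for g, s in zip(gathered, sizes)]
```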


SkafteNicki avatar SkafteNicki commented on July 26, 2024

@maximsch2 I agree that could be a problem.
Are you seeing this when using torchmetrics standalone, or only when it is used together with Lightning?

The reason I am asking is that Lightning by default uses PyTorch's DistributedSampler, which adds extra samples to make sure all processes get an equal workload:
https://github.com/pytorch/pytorch/blob/87242d2393119990ebe9043e854317f02536bdff/torch/utils/data/distributed.py#L105-L114
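A quick illustration of that padding behaviour (a sketch mirroring the linked DistributedSampler logic, with hard-coded numbers for clarity): with 10 samples and 4 replicas, each rank gets ceil(10 / 4) = 3 indices, and the first samples are repeated to round the total up to 12, so every rank sees the same number of batches.

```python
import math

dataset_len, num_replicas = 10, 4
num_samples = math.ceil(dataset_len / num_replicas)   # 3 indices per rank
total_size = num_samples * num_replicas               # 12 in total
indices = list(range(dataset_len))
indices += indices[: total_size - len(indices)]       # pad by repeating [0, 1]
per_rank = [indices[r:total_size:num_replicas] for r in range(num_replicas)]
print(per_rank)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
```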


maximsch2 avatar maximsch2 commented on July 26, 2024

DistributedSampler doesn't work for IterableDataset, which is what we usually have because we read from databases, so we never really use that sampler. This is the fix, btw: #220


SkafteNicki avatar SkafteNicki commented on July 26, 2024

Closing this, as it should have been solved by #220.
Please re-open if the error persists.
