Comments (9)

Noahprog commented on July 17, 2024

I know what the problem is.

In the original TensorFlow implementation, the test evaluation calculates the MAE for every timestep separately. The paper reports only the MAE for the last timestep.
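
Per-timestep evaluation can be sketched in NumPy as below. This is a minimal illustration, not code from either repository: it assumes predictions and labels are plain arrays of shape (horizon, num_samples, num_nodes), ignores any masking of missing values, and the function name `mae_per_horizon` and the toy data are hypothetical.

```python
import numpy as np


def mae_per_horizon(preds: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Return one MAE value per prediction horizon (axis 0).

    Both arrays are assumed to have shape (horizon, num_samples, num_nodes).
    """
    if preds.shape != labels.shape:
        raise ValueError(f"shape mismatch: {preds.shape} vs {labels.shape}")
    # Average the absolute error over every axis except the horizon axis.
    return np.abs(preds - labels).mean(axis=(1, 2))


# Toy example: 3 horizons, 2 samples, 2 nodes (hypothetical data).
labels = np.zeros((3, 2, 2))
preds = np.stack([np.full((2, 2), float(h)) for h in (1, 2, 3)])

for step, mae in enumerate(mae_per_horizon(preds, labels), start=1):
    print(f"Horizon {step}: MAE = {mae:.2f}")
```

This prints an MAE of 1.00, 2.00 and 3.00 for horizons 1 to 3; averaging those values collapses them into a single 2.00.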

The following configuration is used as an example:
[Screenshot (2020-07-08): configuration used for the example run]

The following log (from a run of the original TensorFlow code) shows the separate MAEs for all three timesteps. This PyTorch implementation does not do that; it shows only the average of those values.
[Screenshot (2020-07-08): TensorFlow log with per-timestep MAEs, annotated (1) and (2)]

Here, the validation MAE (marked (1)) is equal to the one from the exact same run with this PyTorch implementation. Also, the MAE of the last timestep (marked (2)) is the same as the MAE reported in the published paper.
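
Why the two numbers differ can be checked with a small NumPy sketch on made-up errors. When every horizon has the same number of samples, the MAE averaged over all timesteps equals the mean of the per-horizon MAEs. That mean is lower than the last-timestep MAE whenever errors grow with the horizon. The error distribution below is a hypothetical assumption chosen only to show this effect.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical absolute errors of shape (horizon, num_samples, num_nodes);
# the error scale is made to grow with the horizon, as in multi-step forecasting.
horizons = 3
abs_err = np.stack(
    [rng.uniform(0.0, 2.0 * (h + 1), size=(100, 5)) for h in range(horizons)]
)

per_horizon = abs_err.mean(axis=(1, 2))  # one MAE per timestep
overall = abs_err.mean()                 # single MAE over all timesteps

# With equal sample counts per horizon, the overall MAE is the mean of the per-horizon MAEs.
assert np.isclose(overall, per_horizon.mean())

print("per-horizon MAE:", np.round(per_horizon, 3))
print("overall MAE:    ", round(float(overall), 3))
print("last-step MAE:  ", round(float(per_horizon[-1]), 3))
```

Comparing the averaged value from one implementation with the last-step value from the other therefore favours the averaged one, even if the models are identical.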

In conclusion, both implementations are equally good, but the PyTorch implementation lacks this separate MAE calculation for every timestep. @chnsh this mistake could have been prevented.

P.S. This is probably the same problem as in issue #7.

from dcrnn_pytorch.

chnsh commented on July 17, 2024

@Noahprog thanks for digging! Good find, want to send a PR to fix it?

cvignac commented on July 17, 2024

It seems that the loss function is not exactly the same.

In the data there are two features: traffic speed and traffic volume (I don't know in what order). It seems that in the PyTorch version the loss is computed on these two features, whereas in the TensorFlow DCRNN it is computed using only the predictions for the first feature (line 272, https://github.com/liyaguang/DCRNN/blob/master/model/dcrnn_supervisor.py).

Is that correct?
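
As a hedged illustration of why this distinction could matter, the NumPy sketch below compares an MAE over all output features with an MAE over the first feature only. The layout (horizon, batch, num_nodes, num_features), the choice of feature 0 as the target, and the synthetic errors are assumptions for illustration, not taken from either repository.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical tensors of shape (horizon, batch, num_nodes, num_features).
# Feature 0 is assumed to be the target; feature 1 is an auxiliary input.
labels = rng.normal(size=(12, 8, 207, 2))
preds = labels.copy()
preds[..., 0] += 1.0  # error of 1.0 on the target feature only

# MAE over every feature: the perfectly predicted auxiliary feature halves the value.
mae_all = np.abs(preds - labels).mean()

# MAE on the first feature only.
mae_first = np.abs(preds[..., 0] - labels[..., 0]).mean()

print(f"MAE over both features: {mae_all:.3f}")   # 0.500
print(f"MAE on feature 0 only:  {mae_first:.3f}")  # 1.000
```

If the second feature is easy to reproduce, including it in the metric makes the reported error look smaller than the error on the target alone.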

chnsh commented on July 17, 2024

@cvignac that is a good catch; however, it may not be the issue. I think the dataset has only one dimension, so indexing on 0, as you pointed out, does nothing to the values.

I haven't been able to dig deeper to figure out why the boost exists, though.

razvanc92 commented on July 17, 2024

@chnsh I'm not sure that's the case. The dataset has two dimensions, speed and time. The model should predict both, since the decoder requires them when predicting the next step, but the evaluation should use only the first dimension (speed).

chnsh commented on July 17, 2024

Oh, I see; thanks for pointing it out. I will try to re-evaluate as soon as possible, though I will be out of office for some time.

chnsh commented on July 17, 2024

@razvanc92 I think you're right that the dataset has 2 dimensions (when it's constructed), but the final loss is calculated by slicing out the first dimension: the final outputs have size (12, 6912, 207) and the dataset is (12, 6912, 207, 2), so the final loss is calculated on the correct dimension and the loss value is not wrong. Can you confirm?
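
The shape bookkeeping chnsh describes can be checked with a NumPy sketch. Only the layout (horizon, samples, nodes[, features]) comes from the comment; the sample axis is scaled down from 6912 to 64 to keep the example light, and the zero/one values are placeholders.

```python
import numpy as np

horizon, num_samples, num_nodes, num_features = 12, 64, 207, 2

# Model outputs carry only the predicted target: (horizon, samples, nodes).
outputs = np.zeros((horizon, num_samples, num_nodes))
# The dataset keeps both input features: (horizon, samples, nodes, features).
dataset = np.ones((horizon, num_samples, num_nodes, num_features))

# Without slicing, broadcasting aligns the trailing axes 207 and 2 and fails.
try:
    outputs - dataset
except ValueError as err:
    print("unsliced subtraction fails:", err)

# Slicing feature 0 drops the last axis, so the labels line up with the outputs.
labels = dataset[..., 0]
assert labels.shape == outputs.shape

mae = np.abs(outputs - labels).mean()
print(labels.shape, float(mae))  # (12, 64, 207) 1.0
```

Because the unsliced subtraction raises an error rather than silently broadcasting, a loss that runs at all with these shapes must be comparing against a single feature.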

Noahprog commented on July 17, 2024

I think this issue still hasn't been sufficiently answered, though.

I'm really curious about why the PyTorch version performs better. I've been digging through the code for quite some time but haven't found anything suspicious. I will keep you up to date if I find something.

semink commented on July 17, 2024

I think the answer is related to #16.
