forward-forward's People

Contributors

loewex, whubaichuan


forward-forward's Issues

I got a runtime error, but I'm not sure what's wrong

I followed the instructions in the README, but on Windows 10, and got the errors below. Can someone explain what is happening?

seed: 42
device: cuda
input:
  path: datasets
  batch_size: 100
model:
  peer_normalization: 0.03      
  momentum: 0.9
  hidden_dim: 1000
  num_layers: 3
training:
  epochs: 100
  learning_rate: 0.001
  weight_decay: 0.0003
  momentum: 0.9
  downstream_learning_rate: 0.01
  downstream_weight_decay: 0.003
  val_idx: -1
  final_test: false

FF_model(
  (model): ModuleList(
    (0): Linear(in_features=784, out_features=1000, bias=True)
    (1): Linear(in_features=1000, out_features=1000, bias=True)
    (2): Linear(in_features=1000, out_features=1000, bias=True)
  )
  (ff_loss): BCEWithLogitsLoss()
  (linear_classifier): Sequential(
    (0): Linear(in_features=2000, out_features=10, bias=False)
  )
  (classification_loss): CrossEntropyLoss()
) 

Error executing job with overrides: []
Traceback (most recent call last):
  File "C:\my_file\0_research\2023_FF\Forward-Forward\main.py", line 72, in my_main
    model = train(opt, model, optimizer)
  File "C:\my_file\0_research\2023_FF\Forward-Forward\main.py", line 20, in train
    for inputs, labels in train_loader:
  File "C:\Users\User\anaconda3\envs\FF\lib\site-packages\torch\utils\data\dataloader.py", line 530, in __next__
    data = self._next_data()
  File "C:\Users\User\anaconda3\envs\FF\lib\site-packages\torch\utils\data\dataloader.py", line 1224, in _next_data
    return self._process_data(data)
  File "C:\Users\User\anaconda3\envs\FF\lib\site-packages\torch\utils\data\dataloader.py", line 1250, in _process_data
    data.reraise()
  File "C:\Users\User\anaconda3\envs\FF\lib\site-packages\torch\_utils.py", line 457, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "C:\Users\User\anaconda3\envs\FF\lib\site-packages\torch\utils\data\_utils\worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "C:\Users\User\anaconda3\envs\FF\lib\site-packages\torch\utils\data\_utils\fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\User\anaconda3\envs\FF\lib\site-packages\torch\utils\data\_utils\fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\my_file\0_research\2023_FF\Forward-Forward\src\ff_mnist.py", line 15, in __getitem__
    pos_sample, neg_sample, neutral_sample, class_label = self._generate_sample(
  File "C:\my_file\0_research\2023_FF\Forward-Forward\src\ff_mnist.py", line 58, in _generate_sample
    neg_sample = self._get_neg_sample(sample, class_label)
  File "C:\my_file\0_research\2023_FF\Forward-Forward\src\ff_mnist.py", line 43, in _get_neg_sample
    one_hot_label = torch.nn.functional.one_hot(
RuntimeError: one_hot is only applicable to index tensor.


Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
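A likely cause, offered as an assumption since the source of `ff_mnist.py` is not shown here: `torch.nn.functional.one_hot` only accepts int64 (`torch.long`) index tensors. On Windows with NumPy earlier than 2.0, NumPy's default integer type is int32, so if the label passed to `one_hot` at `ff_mnist.py` line 43 is drawn with NumPy and wrapped with `torch.tensor`, it arrives as an int32 tensor and triggers exactly this error. The sketch below shows the dtype issue and an explicit cast; `one_hot_label` is a hypothetical helper, not the repository's code.

```python
import numpy as np
import torch
import torch.nn.functional as F


def one_hot_label(label, num_classes=10):
    """Hypothetical helper: one-hot encode a single class label.

    F.one_hot requires an int64 (torch.long) index tensor, so the label is
    converted to a Python int and wrapped as torch.long. This works for a
    Python int, a NumPy scalar (int32 or int64) or a 0-d tensor.
    """
    return F.one_hot(torch.tensor(int(label), dtype=torch.long), num_classes=num_classes)


# A NumPy int32 scalar keeps its dtype when wrapped with torch.tensor,
# which is what one_hot rejects.
label = np.int32(3)
print(torch.tensor(label).dtype)  # torch.int32

# With the explicit cast the encoding succeeds.
print(one_hot_label(label))  # tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
```

If this is the cause, adding `.long()` (or `dtype=torch.long`) where the label tensor is created in `_get_neg_sample` should be enough; the rest of the pipeline does not need to change.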

Question regarding logits = sum_of_squares - z.shape[1]

In src/ff_model.py, line 76:
logits = sum_of_squares - z.shape[1]
ff_loss = self.ff_loss(logits, labels.float())

I have thought about this for a long time but could not figure out why you wrote (sum_of_squares - z.shape[1]). Could you clarify it for me?
Also, is this implementation of ff_loss the same as in Hinton's original paper? As I understand it, it is not; why did you make this modification?
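One way to read the snippet, as a sketch rather than the author's stated reasoning: Hinton's paper defines the probability that an input is positive as p(positive) = σ(Σ_j y_j² − θ), where Σ_j y_j² is the layer's "goodness" and θ is a threshold. If `sum_of_squares` is that goodness and the threshold is set to the number of units, θ = `z.shape[1]`, then `logits` is exactly the argument of the sigmoid, and `BCEWithLogitsLoss` applies the sigmoid and the logistic loss internally. Since Σ_j y_j² − d = d · (mean_j y_j² − 1), this threshold amounts to asking for a mean squared activity above 1 per unit. The function name `ff_layer_loss` below is hypothetical; the check compares it against the loss computed by hand from the paper's formula.

```python
import torch
import torch.nn as nn


def ff_layer_loss(z, labels):
    # Goodness of each sample: sum of squared activities over the units.
    sum_of_squares = torch.sum(z ** 2, dim=-1)
    # Threshold theta = number of units, mirroring the snippet above.
    logits = sum_of_squares - z.shape[1]
    # BCEWithLogitsLoss applies the sigmoid itself, so this is the logistic
    # loss on p(positive) = sigmoid(goodness - theta).
    return nn.BCEWithLogitsLoss()(logits, labels.float())


torch.manual_seed(0)
z = torch.randn(4, 8)                 # 4 samples, 8 hidden units
labels = torch.tensor([1, 0, 1, 0])   # 1 = positive, 0 = negative sample

loss = ff_layer_loss(z, labels)

# The same value computed by hand from the paper's probability formula.
p = torch.sigmoid((z ** 2).sum(dim=-1) - z.shape[1])
manual = -(labels * torch.log(p) + (1 - labels) * torch.log(1 - p)).mean()
print(loss.item(), manual.item())
```

Whether a fixed θ equal to the layer width matches the paper's exact choice of threshold is something only the authors can confirm; the sketch only shows that the logistic form is the same.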

Error in main_model_params

Hi, nice work.

The parameter filter for the main model's optimizer seems to be wrong. The code in this file uses:

if all(p is not x for x in model.classification_loss.parameters())

but model.classification_loss.parameters() is empty, so all(...) is always True and every parameter, including those of the linear classifier, ends up in the main model's group. I suspect the condition should be changed to: if all(p is not x for x in model.linear_classifier.parameters())?
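A minimal sketch of the proposed fix, assuming the optimizer is built from two parameter groups. `TinyFF` is a hypothetical stand-in that reuses the attribute names printed in the `FF_model(...)` dump above with smaller layers; the choice of `torch.optim.SGD` is an assumption, and the learning rates and weight decays are taken from the config in the first issue.

```python
import torch
import torch.nn as nn


class TinyFF(nn.Module):
    # Hypothetical stand-in with the same attribute names as FF_model.
    def __init__(self):
        super().__init__()
        self.model = nn.ModuleList([nn.Linear(784, 16), nn.Linear(16, 16)])
        self.linear_classifier = nn.Sequential(nn.Linear(16, 10, bias=False))
        self.classification_loss = nn.CrossEntropyLoss()


model = TinyFF()

# A loss module owns no parameters, so all(...) over it is vacuously True
# and the original filter would keep the classifier weights as well.
print(len(list(model.classification_loss.parameters())))  # 0

# Proposed filter: exclude the classifier's parameters explicitly.
main_model_params = [
    p
    for p in model.parameters()
    if all(p is not x for x in model.linear_classifier.parameters())
]

optimizer = torch.optim.SGD(
    [
        # FF layers: training.learning_rate / weight_decay / momentum.
        {"params": main_model_params, "lr": 1e-3, "weight_decay": 3e-4, "momentum": 0.9},
        # Classifier: downstream_learning_rate / downstream_weight_decay.
        {
            "params": model.linear_classifier.parameters(),
            "lr": 1e-2,
            "weight_decay": 3e-3,
            "momentum": 0.9,
        },
    ]
)
```

With the original condition the classifier weight would appear in both groups, which PyTorch rejects with an error about parameters appearing in more than one group, or, depending on how the groups are assembled, would be trained with the wrong hyperparameters.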

How does the implementation achieve layer-by-layer training?

Hey, for someone not very familiar with how PyTorch autograd works internally:
Would you be able to shed some light on how your implementation realizes the two forward passes and layer-by-layer training?

(It seems that you do everything in a single pass, covering positive data, negative data, and layer-by-layer training, whereas other implementations make the layer-by-layer optimization more explicit.)
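A common way to get layer-local training out of one forward pass, shown here as a sketch and not as a description of this repository's code: stack positive and negative samples in one batch (with labels 1 and 0), compute a loss at every layer, and `detach()` each layer's output before feeding it to the next layer. Because the next layer then sees a constant input, its loss cannot send gradients back into earlier layers, so summing all layer losses and calling `backward()` once gives every layer the gradient of its own loss only. The layer sizes, the ReLU, the length normalization and the name `ff_losses` are assumptions for illustration; the last lines check that layer 0 receives the same gradient whether you backpropagate the sum of losses or layer 0's loss alone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layers = nn.ModuleList([nn.Linear(6, 5), nn.Linear(5, 5)])


def ff_losses(x, labels):
    # x stacks positive and negative samples; labels is 1.0 for positive
    # rows and 0.0 for negative rows, so one pass covers both "passes".
    losses = []
    z = x
    for layer in layers:
        z = torch.relu(layer(z))
        logits = (z ** 2).sum(dim=-1) - z.shape[1]
        losses.append(F.binary_cross_entropy_with_logits(logits, labels))
        # Cut the graph: the next layer treats this activity as a constant.
        z = z.detach()
        # Length-normalize before the next layer (assumed detail).
        z = z / (z.norm(dim=-1, keepdim=True) + 1e-8)
    return losses


x = torch.randn(8, 6)
labels = torch.tensor([1.0] * 4 + [0.0] * 4)

# One backward call on the summed losses.
sum(ff_losses(x, labels)).backward()
grad_from_all_losses = layers[0].weight.grad.clone()

# Backward on layer 0's own loss only.
layers.zero_grad()
ff_losses(x, labels)[0].backward()
print(torch.allclose(grad_from_all_losses, layers[0].weight.grad))  # True
```

The design choice is that the graph is cut rather than the training loop split: one optimizer step can update all layers at once, yet each layer is still optimized only for its own local objective.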
