mo-ahsan-ahmad's Introduction

Mo-Ahsan-Ahmad

COSC 5P77, SPN project under the guidance of Prof. Dr. Li.

Dept. of Comp. Science, Brock University, ON, CA

Sum-Product Networks (SPNs) are deep probabilistic models, closely related to probabilistic graphical models, whose computation graph alternates sum nodes (weighted mixtures) and product nodes (factorizations). SPNs combine ideas from graphical models and neural networks to form a hierarchical structure capable of representing complex probability distributions.

In our practical approach, we loaded the MNIST dataset with PyTorch and normalized it with a transformation (mean = 0.5, standard deviation = 0.5). The dataset comprises 60,000 training samples and 10,000 test samples. For efficient processing, we used a batch size of 64. As a result, the training loader yields 938 batches (60000/64 = 937.5, rounded up) and the test loader yields 157 batches (10000/64 = 156.25, rounded up). Every batch holds 64 samples except the last one, which holds 32 samples in the training loader and 16 in the test loader. We then initialized the weights of our Sum-Product Network (SPN) class and configured its architecture and settings for training and inference. This setup is the basis for implementing and experimenting with SPNs on the MNIST dataset.
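
The batch counts and the effect of the normalization can be checked with a few lines of plain Python. This is a minimal sketch: the helper names are hypothetical, it assumes the loaders keep the final partial batch (the PyTorch `DataLoader` default, `drop_last=False`), and it assumes pixel values are already scaled to [0, 1] before normalization.

```python
import math


def num_batches(num_samples: int, batch_size: int) -> int:
    """Number of batches when the final partial batch is kept."""
    return math.ceil(num_samples / batch_size)


def last_batch_size(num_samples: int, batch_size: int) -> int:
    """Size of the final batch (equals batch_size when the split is exact)."""
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size


def normalize(pixel: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Apply (x - mean) / std, the same formula a Normalize transform uses."""
    return (pixel - mean) / std


train_batches = num_batches(60_000, 64)   # training loader
test_batches = num_batches(10_000, 64)    # test loader
print(train_batches, last_batch_size(60_000, 64))
print(test_batches, last_batch_size(10_000, 64))
print(normalize(0.0), normalize(1.0))     # ends of the [0, 1] pixel range
```

With mean 0.5 and standard deviation 0.5, inputs in [0, 1] are mapped to [-1, 1].
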
Initializing Weights: A helper function initializes the weights of the network's layers using Xavier initialization.
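
A weight-initialization helper of this kind might look as follows. This is a sketch, not the project's actual function: the name `init_weights`, the choice of the uniform Xavier variant, the zeroed biases, and the restriction to `nn.Linear` layers are all assumptions.

```python
import math

import torch
import torch.nn as nn


def init_weights(module: nn.Module) -> None:
    """Xavier-initialize linear layers (assumed variant: uniform, zero bias)."""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


torch.manual_seed(0)
# Hypothetical two-layer network, used only to demonstrate the helper.
demo = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
demo.apply(init_weights)  # apply() visits every submodule recursively

# Xavier uniform samples from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)).
bound = math.sqrt(6.0 / (784 + 128))
print(demo[0].weight.abs().max().item() <= bound)
```
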
SPN Model Definition: The SimpleSPN class is defined as a subclass of torch.nn.Module. It takes the input size, the sizes of the sum and product layers, and the number of classes as parameters. Inside the constructor (__init__), the sum layers are defined as a list of nn.Sequential modules, the product layers as a list of other modules, and an output layer is also defined.
Forward Pass: In the forward method of the SimpleSPN class, input data is processed through the sum and product layers sequentially. A softmax or ReLU activation function is applied after each sum layer. Finally, the output layer produces the classification predictions.
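
The description above leaves several details open, so the following is only a sketch of a class with that shape. Assumptions: sum and product layers alternate, each is built from `nn.Linear` (a placeholder, not a true SPN product node), ReLU is the activation used, and `nn.ModuleList` holds the layers so that their parameters are registered with the module.

```python
import torch
import torch.nn as nn


class SimpleSPN(nn.Module):
    """Layered sketch: alternating 'sum' and 'product' layers, then an output layer."""

    def __init__(self, input_size, sum_sizes, product_sizes, num_classes):
        super().__init__()
        if len(sum_sizes) != len(product_sizes):
            raise ValueError("sum_sizes and product_sizes must have equal length")
        self.sum_layers = nn.ModuleList()
        self.product_layers = nn.ModuleList()
        width = input_size
        for sum_size, product_size in zip(sum_sizes, product_sizes):
            # Sum layer: a learned weighted combination of its inputs.
            self.sum_layers.append(nn.Sequential(nn.Linear(width, sum_size)))
            # Product layer: placeholder linear map (assumption, see lead-in).
            self.product_layers.append(nn.Linear(sum_size, product_size))
            width = product_size
        self.output_layer = nn.Linear(width, num_classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten images to (batch, input_size)
        for sum_layer, product_layer in zip(self.sum_layers, self.product_layers):
            x = torch.relu(sum_layer(x))  # activation after each sum layer
            x = product_layer(x)
        return self.output_layer(x)  # raw class scores (logits)


# Hypothetical layer sizes for 28x28 MNIST images and 10 digit classes.
model = SimpleSPN(784, sum_sizes=[128, 64], product_sizes=[64, 32], num_classes=10)
out = model(torch.randn(5, 1, 28, 28))
print(out.shape)
```
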
Cross-Entropy Loss: nn.CrossEntropyLoss is chosen as the loss function for the classification task. This criterion computes the cross-entropy loss between the model's predictions and the target labels during training.
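
As a small illustration of how this criterion behaves, the sketch below feeds it hand-made scores rather than real model outputs. `nn.CrossEntropyLoss` expects raw, unnormalized scores (logits) and applies log-softmax internally; the tensors and variable names here are illustrative assumptions.

```python
import math

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
targets = torch.tensor([3, 1, 4, 1])  # one class index per sample

# Uniform scores over 10 classes: every class gets probability 1/10,
# so the loss equals -log(1/10) = log(10).
uniform_logits = torch.zeros(4, 10)
uniform_loss = criterion(uniform_logits, targets).item()

# Scores that strongly favor the correct class give a loss close to zero.
confident_logits = torch.full((4, 10), -10.0)
confident_logits[torch.arange(4), targets] = 10.0
confident_loss = criterion(confident_logits, targets).item()

print(uniform_loss, math.log(10))
print(confident_loss)
```
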
Optimizer Setup: Three distinct optimizers (Adam, RMSprop, and Adamax) are used to train the SPN model. Each optimizer is constructed from the SPN model's parameters and a learning rate.
Training Function: The train_model function trains the SPN model for a set number of epochs using the provided optimizer and criterion. Inside the function, the model is switched to training mode and the training data is iterated in batches. For each batch, the loss is calculated, gradients are computed by backpropagation, and the optimizer updates the model parameters.
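
A training loop following these steps can be sketched as below. The signature of `train_model`, the returned loss history, and the synthetic data and linear model used to exercise it are assumptions made for illustration; the project's own function is not shown in this README.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_model(model, loader, optimizer, criterion, epochs=1):
    """Train for `epochs` passes over `loader`; return the mean loss per epoch."""
    model.train()  # enable training-mode behavior (dropout, batch norm)
    history = []
    for _ in range(epochs):
        running_loss, seen = 0.0, 0
        for inputs, labels in loader:
            optimizer.zero_grad()                    # clear old gradients
            loss = criterion(model(inputs), labels)  # forward pass + loss
            loss.backward()                          # backpropagation
            optimizer.step()                         # parameter update
            running_loss += loss.item() * inputs.size(0)
            seen += inputs.size(0)
        history.append(running_loss / seen)
    return history


# Synthetic, linearly separable stand-in for MNIST (assumption for the demo).
torch.manual_seed(0)
features = torch.randn(256, 20)
labels = (features[:, 0] > 0).long()
loader = DataLoader(TensorDataset(features, labels), batch_size=64, shuffle=True)

model = nn.Linear(20, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
losses = train_model(model, loader, optimizer, nn.CrossEntropyLoss(), epochs=10)
print(losses[0], losses[-1])
```
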
Testing Function: The test_model function assesses the trained model's performance on the test dataset. The model is switched to evaluation mode and the test data is iterated in batches. Loss and accuracy metrics are calculated and reported.
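
An evaluation routine of this kind might be written as follows. It is a sketch under assumptions: the function is named `evaluate_model` here (a hypothetical name standing in for `test_model`), it returns the metrics instead of printing them, and the fixed-weight model and tiny dataset exist only to make the result easy to check by hand.

```python
import math

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate_model(model, loader, criterion):
    """Return (mean loss, accuracy) of `model` over all samples in `loader`."""
    model.eval()  # disable training-only behavior
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for inputs, labels in loader:
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item() * inputs.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += inputs.size(0)
    return total_loss / seen, correct / seen


# A model that always outputs the scores [1, 0], i.e. always predicts class 0.
model = nn.Linear(3, 2)
with torch.no_grad():
    model.weight.zero_()
    model.bias.copy_(torch.tensor([1.0, 0.0]))

features = torch.randn(8, 3)
labels = torch.tensor([0, 0, 0, 1, 0, 0, 0, 1])  # 6 of 8 samples are class 0
loader = DataLoader(TensorDataset(features, labels), batch_size=4)

avg_loss, accuracy = evaluate_model(model, loader, nn.CrossEntropyLoss())
print(avg_loss, accuracy)
```

Weighting each batch loss by its batch size keeps the mean correct when the last batch is smaller than the others, as it is for MNIST with a batch size of 64.
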
Training and Testing: Finally, the code loops through each optimizer, re-applies the weight initialization, trains the model using the train_model function, and evaluates its performance on the test dataset using the test_model function.
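
The overall comparison loop can be sketched in a self-contained form. Everything concrete here is an assumption: the small stand-in model, the synthetic data with a simple train/held-out split, the learning rate of 0.01, the number of steps, and the inlined training and evaluation code that takes the place of `train_model` and `test_model`.

```python
import torch
import torch.nn as nn


def init_weights(module):
    """Xavier-initialize linear layers (assumed variant: uniform, zero bias)."""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        nn.init.zeros_(module.bias)


# The three optimizers named in the text, keyed by a display name.
optimizer_classes = {
    "Adam": torch.optim.Adam,
    "RMSprop": torch.optim.RMSprop,
    "Adamax": torch.optim.Adamax,
}

torch.manual_seed(0)
features = torch.randn(128, 20)
labels = (features[:, 0] > 0).long()
train_x, train_y = features[:96], labels[:96]  # training split
test_x, test_y = features[96:], labels[96:]    # held-out split

criterion = nn.CrossEntropyLoss()
model = nn.Sequential(nn.Linear(20, 16), nn.ReLU(), nn.Linear(16, 2))

results = {}
for name, optimizer_cls in optimizer_classes.items():
    model.apply(init_weights)  # fresh weights so each optimizer starts anew
    optimizer = optimizer_cls(model.parameters(), lr=0.01)  # hypothetical rate

    model.train()
    for _ in range(50):  # full-batch steps stand in for train_model
        optimizer.zero_grad()
        loss = criterion(model(train_x), train_y)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():  # stands in for test_model
        predictions = model(test_x).argmax(dim=1)
    results[name] = (predictions == test_y).float().mean().item()

for name, accuracy in results.items():
    print(f"{name}: held-out accuracy {accuracy:.3f}")
```

Re-initializing the weights inside the loop matters for a fair comparison: without it, each optimizer would continue from the parameters the previous one left behind.
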
