Git Product home page Git Product logo

astra-pytorch's Introduction

astra-pytorch

Setting up AstraDB

  1. Go to astra.datastax.com and sign up for a free account.
  2. Create a database using the add database button
  3. Go to the CQL Console and create the required tables a. CREATE TABLE mnist_digits.raw_train (id int PRIMARY KEY, label int, pixels list); b. CREATE TABLE mnist_digits.raw_test (id int PRIMARY KEY, label int, pixels list); c. CREATE TABLE mnist_digits.raw_predict (id int PRIMARY KEY, label int, pixels list); d. CREATE TABLE mnist_digits.models (id uuid PRIMARY KEY, network blob, optimizer blob, upload_date timestamp, comments text);
  4. Download a secure connect bundle
  5. Generate a Database Administrator token

Connecting to AstraDB in the code environment

  1. Load the secure connect bundle into the environment
  2. Input the bundle filepath, the client id, and the token into auth.py

Otherwise setting up the code environment

  1. Install python rquirements using pip3 install -r requirements.txt
  2. Load data into Astra by running load_raw_data - this may take an hour or more to complete

Running through the example

  1. From there you should be able to step through the notebook and train a model

The first thing we do in the notebook is import things neccesary for creating a Pytorch Dataset and Data Loader that connect to Astra. A data loader holds a certain amount of data this it provides to Pytorch in batches when asked. A Dataset defines how the loader pulls that data when asked.

We then create this AstraDataset. In this case we define a dataset that queries for individual rows and tranforms those rows to be in a format that Pytorch can tranform into tensors. The data in the Astra table is a Cassandra list collection of integer pixel values, that translates into a Python list of integer values, which the loader tranforms in a scaled numpy array of floats, which Pytorch can then transform into a tensor and scale and shift again.

After that we import again. This time things that are needed for the definition and training of a Pytorch neural network.

First we create our connection to Astra and use that to create the train and test set data loaders. Then we define constants that are used in the creation of our nerual net. These do this like change the number of training epochs and the setting for the backpropagation step of training a neural net.

After that we create a class Net that defines the structure and forwards propagation step of our nueral net. These will be used to make predictions on data later, but when they are just initialized it does basically random things to the data.

So we create an optimizer that does a backprogatation step, using the training data to build a gradient and doing gradient descent steps to make the model adhere closer to what is shown by the training data. After a couple of these optimization steps we can create a model much better than change at identifying the digits.

We create functions to handle the training and testing steps and then run them to create our model. During the training step we also save the current model with a record of its performance on the training set as well as the optimizer state so we could pcik up trianing from that point in the future if desired.

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.