Git Product home page Git Product logo

greendev7 / neuralnetwork Goto Github PK

View Code? Open in Web Editor NEW
0.0 2.0 1.0 53.04 MB

Реализация алгоритма обратного распространения ошибки для обучения нейронной сети для распознавания рукописных цифр

Home Page: https://codesandbox.io/s/winter-dew-u26xeb

C# 100.00%
digit-recognition-problem digit-recognition-application activation-functions backpropagation backpropagation-algorithm backpropagation-learning-algorithm backpropagation-neural-network digit-recognition-mnist handwriting-recognition handwritten-digit-recognition

neuralnetwork's Introduction

NeuralNetwork

Реализация классического алгоритма обратного распространения ошибки с помощью последовательного (стохастического) режима обучения для тренировки многослойного персептрона распознаванию рукописных цифр

Digit Recognition Canvas (песочница)

Digit Recognition Canvas (исходники)

Описание алгоритма (Хабр)

Form1.cs

private void learnButton_Click(object sender, EventArgs e)
        {
            string myDocumentFolder = Environment.GetFolderPath(Environment.SpecialFolder.MyDocuments);

            string trainingImagesPath = Path.Combine(myDocumentFolder, "train-images-idx3-ubyte");
            string trainingLabelsPath = Path.Combine(myDocumentFolder, "train-labels-idx1-ubyte");

            string testImagesPath = Path.Combine(myDocumentFolder, "t10k-images-idx3-ubyte");
            string testLabelsPath = Path.Combine(myDocumentFolder, "t10k-labels-idx1-ubyte");

            #region Блок для инициализации нейросети с помощью весовых коэффициентов из файлов csv

            //// Считываем весовые коэффициенты из файлов
            // List<Layer> hiddenLayers = InitializeHiddenLayersWeightsFromCSVFile(Path.Combine(myDocumentFolder, "adjustedHiddenLayerWeights_acc9572_16.csv"));
            // Layer outputLayer = InitializeOutputLayerWeightsFromCSVFile(Path.Combine(myDocumentFolder, "adjustedOutputLayerWeights_acc9572_16.csv"));
            ////Инициализируем нейросеть
            //Network network = new Network(hiddenLayers, outputLayer);
            
            #endregion


            #region Блок для инициализация нейросети рандомными значениями и ее обучение

            // Инициализируем нейросеть с помощью заданных параметров

            int hiddenLayersCount = 1;  // Задаем количество скрытых слоев
            int[] hiddenLayersDimensions = new int[hiddenLayersCount]; // Массив для хранения количества нейронов на каждом скрытом слое
            Func<double, double>[] hiddenActivationFunctions = new Func<double, double>[hiddenLayersCount]; // Массив для хранения функций активации на каждом скрытом слое

            hiddenLayersDimensions[0] = 80; // У нас один скрытый слой на котором 80 нейронов
            hiddenActivationFunctions[0] = ActivationFunctions.SigmoidFunction; // И для всех нейронов этого скрытого слоя используется сигмоидальная функция активации

            // 784 входа - это размер массива полученного из изображения (28 * 28 пикселей)
            Network network = new Network(784, 10, ActivationFunctions.SigmoidFunction, hiddenLayersDimensions, hiddenActivationFunctions);
            network.Train(trainingImagesPath, trainingLabelsPath, 0.2, 16); // Запускаем обучение

            #endregion


            #region Тестируем нейросеть на тестовой выборке в 10 000 изображений

            // Получаем тестовые изображения
            IEnumerable<TestCase> testCases = FileReaderMNIST.LoadImagesAndLables(testLabelsPath, testImagesPath);

            int incorrectPredictionsCount = 0; // счетчик неверно предсказанных результатов
            foreach (TestCase test in testCases)
            {
                List<double> functionSignal = ImageHelper.ConvertImageToFunctionSignal(test.Image); // Преобразуем изображение в вектор размерности 784 состоящий из нулей и единичек

                List<double> outputSignal = network.MakePropagateForward(functionSignal); // Получаем сигнал от нейросети
                int predictedDigit = outputSignal.IndexOf(outputSignal.Max()); // Предсказанную цифру находим как индекс максимального элемента массива

                // Если нейросеть выдала некорректный ответ
                if (test.Label != predictedDigit)
                {
                    incorrectPredictionsCount++;
                    // Получим это изображение
                    Bitmap bitmap = ImageHelper.CreateBitmapFromMnistImage(test.Image);                    
                    // И сохраним в папку IncorrectPredictions
                    bitmap.Save(Path.Combine(myDocumentFolder, "IncorrectPredictions", $"{incorrectPredictionsCount}_{test.Label}_{predictedDigit}.png"));
                }
            }

            double accuracy = 100.0 - (incorrectPredictionsCount / 100.0); // Вычисляем точность (%)
            #endregion


            // Записываем скорректированные весовые коэффициенты в файлы
            network.WriteHiddenWeightsToCSVFile(Path.Combine(myDocumentFolder, $"adjustedHiddenLayerWeights_acc{accuracy.ToString().Replace(",", string.Empty)}.csv"));
            network.WriteOutputWeightsToCSVFile(Path.Combine(myDocumentFolder, $"adjustedOutputLayerWeights_acc{accuracy.ToString().Replace(",", string.Empty)}.csv"));
            // и в JSON файлы
            network.WriteHiddenWeightsToJsonFile(Path.Combine(myDocumentFolder, $"adjustedHiddenLayerWeights_acc{accuracy.ToString().Replace(",", string.Empty)}.json"));
            network.WriteOutputWeightsToJsonFile(Path.Combine(myDocumentFolder, $"adjustedOutputLayerWeights_acc{accuracy.ToString().Replace(",", string.Empty)}.json"));
        }

neuralnetwork's People

Watchers

 avatar  avatar

Forkers

pivopls

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.