Git Product home page Git Product logo

dl-project-template's Introduction

DL Project Template

深度学习工程模板,简化加载数据、构建网络、训练模型和预测样本的流程。

By C. L. Wang @ 美图云事业部

工程:https://github.com/SpikeKing/DL-Project-Template

使用方式

下载工程

git clone https://github.com/SpikeKing/DL-Project-Template

创建和激活虚拟环境

virtualenv venv
source venv/bin/activate

安装Python依赖库

pip install -r requirements.txt

开发流程

  1. 定义自己的数据加载类,继承DataLoaderBase;
  2. 定义自己的网络结构类,继承ModelBase;
  3. 定义自己的模型训练类,继承TrainerBase;
  4. 定义自己的样本预测类,继承InferBase;
  5. 定义自己的配置文件,写入实验的相关参数;

执行训练模型和预测样本操作。

示例工程

识别MNIST库中手写数字,工程simple_mnist

训练:

python main_train.py -c configs/simple_mnist_config.json

预测:

python main_test.py -c configs/simple_mnist_config.json -m simple_mnist.weights.10-0.24.hdf5

网络结构

网络结构

TensorBoard

TensorBoard

工程架构

框架图

架构

文件夹结构

├── bases
│   ├── data_loader_base.py             - 数据加载基类
│   ├── infer_base.py                   - 预测样本(推断)基类
│   ├── model_base.py                   - 网络结构(模型)基类
│   ├── trainer_base.py                 - 训练模型基类
├── configs                             - 配置文件夹
│   └── simple_mnist_config.json
├── data_loaders                        - 数据加载文件夹
│   ├── __init__.py
│   ├── simple_mnist_dl.py
├── experiments                         - 实验数据文件夹
│   └── simple_mnist                    - 实验名称
│       ├── checkpoints                 - 存储的模型和参数
│       │   └── simple_mnist.weights.10-0.24.hdf5
│       ├── images                      - 图片
│       │   └── model.png
│       └── logs                        - 日志,如TensorBoard
│           └── events.out.tfevents.1524034653.wang
├── infers                              - 推断文件夹
│   ├── __init__.py
│   ├── simple_mnist_infer.py
├── main_test.py                        - 预测样本入口
├── main_train.py                       - 训练模型入口
├── models                              - 网络结构文件夹
│   ├── __init__.py
│   ├── simple_mnist_model.py
├── requirements.txt                    - 依赖库
├── trainers                            - 训练模型文件夹
│   ├── __init__.py
│   ├── simple_mnist_trainer.py
└── utils                               - 工具文件夹
    ├── __init__.py
    ├── config_utils.py                 - 配置工具类
    ├── np_utils.py                     - NumPy工具类
    ├── utils.py                        - 其他工具类

主要组件

DataLoader

操作步骤:

  1. 创建自己的加载数据类,继承DataLoaderBase基类;
  2. 覆写get_train_data()get_test_data(),返回训练和测试数据;

Model

操作步骤:

  1. 创建自己的网络结构类,继承ModelBase基类;
  2. 覆写build_model(),创建网络结构;
  3. 在构造器中,调用build_model()

注意:plot_model()支持绘制网络结构;

Trainer

操作步骤:

  1. 创建自己的训练类,继承TrainerBase基类;
  2. 参数:网络结构model、训练数据data;
  3. 覆写train(),fit数据,训练网络结构;

注意:支持在训练中调用callbacks,额外添加模型存储、TensorBoard、FPR度量等。

Infer

操作步骤:

  1. 创建自己的预测类,继承InferBase基类;
  2. 覆写load_model(),提供模型加载功能;
  3. 覆写predict(),提供样本预测功能;

Config

定义在模型训练过程中所需的参数,JSON格式,支持:学习率、Epoch、Batch等参数。

Main

训练:

  1. 创建配置文件config;
  2. 创建数据加载类dataloader;
  3. 创建网络结构类model;
  4. 创建训练类trainer,参数是训练和测试数据、模型;
  5. 执行训练类trainer的train();

预测:

  1. 创建配置文件config;
  2. 处理预测样本test;
  3. 创建预测类infer;
  4. 执行预测类infer的predict();

感谢

参考Tensorflow-Project-Template

dl-project-template's People

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

dl-project-template's Issues

可以升级一下 python 版本

目前主流都在往 python3 转,官方也说要停止对 python2 的支持了。
用 python 自带的 2to3.py 可以一键转换。

BTW,不错的模板。

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.