T60_train_test

-------------Training------------

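1. Rewrote new_data_load_original.py so that it automatically detects the data type.

In the old data each image holds a single frequency band, while in the new data each image holds multiple frequency bands. The loader detects which of the two it is given and reads it accordingly.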
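The detection described above could look roughly like the sketch below. It is not the code in new_data_load_original.py: the helper names count_bands and band_index are hypothetical, and the layout (2-D for a single-band image, 3-D with the band axis first for a multi-band image) is an assumption made only for illustration.

```python
from typing import Sequence


def count_bands(shape: Sequence[int]) -> int:
    """Return how many frequency bands an image of this shape holds.

    Assumption (not confirmed by the repository): old single-band images
    are 2-D (height, width); new multi-band images are 3-D
    (bands, height, width).
    """
    if len(shape) == 2:
        return 1
    if len(shape) == 3:
        return shape[0]
    raise ValueError(f"unexpected image shape: {tuple(shape)}")


def band_index(shape: Sequence[int], which_freq: int) -> int:
    """Pick the band to read: 0 for old data, which_freq for new data."""
    n_bands = count_bands(shape)
    if n_bands == 1:
        return 0  # old data: the only band is the one we want
    if not 0 <= which_freq < n_bands:
        raise IndexError(f"which_freq={which_freq} out of range for {n_bands} bands")
    return which_freq
```

With a real tensor, the shape argument would be something like tuple(image.shape).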

2. Run python train_resnet50_500Hz.py directly (single GPU).

Two data paths must be supplied: train_dict_root and train_dict_root1. This allows several datasets to be trained together; train_dict_root1 may be None.

train_dict_root = "/data/xbj/0927_500hz_with_clean/train"
train_dict_root1 = "/data2/hsl/0324_pt_with_clean"  # path to the new data

The call in the code looks like this:

train_transformed_dataset = Dataset_dict(root_dir=train_dict_root,root_dir1=train_dict_root1, transform=data_transform,
                                             start_freq=args.start_freq, end_freq=args.end_freq, which_freq=args.which_freq, failed_file=failed_file)
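As a rough illustration of how two roots can be combined when the second one may be None, here is a minimal sketch. The helper collect_pt_files is hypothetical and is not the Dataset_dict implementation; it assumes only that samples are stored as .pt files somewhere under each root.

```python
from pathlib import Path
from typing import List, Optional


def collect_pt_files(root_dir: str, root_dir1: Optional[str] = None) -> List[Path]:
    """Gather every .pt file under root_dir and, if given, under root_dir1.

    Hypothetical helper for illustration; the real dataset class may
    index its files differently.
    """
    files: List[Path] = []
    for root in (root_dir, root_dir1):
        if root is None:
            continue  # second dataset is optional
        # rglob walks sub-folders; sorted() keeps the order reproducible
        files.extend(sorted(Path(root).rglob("*.pt")))
    return files
```

Calling collect_pt_files(train_dict_root, None) then behaves the same as training on the first dataset alone.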

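In addition, change the training frequency. Each frequency corresponds to a (which_freq, freq_loc) pair: 125Hz:(0, 7); 250Hz:(1, 10); 500Hz:(2, 13); 1kHz:(3, 16); 2kHz:(4, 19).

    parser.add_argument('--which_freq', type=int, default=2)   # remember to change this to the matching image index
    parser.add_argument('--freq_loc', type=int, default=13)    # remember to change this to the matching octave band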
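Because both values are ordinary argparse options, they can also be passed on the command line instead of editing the defaults. The sketch below shows the idea; FREQ_TABLE and freq_args are hypothetical names, the table simply restates the pairs listed above, and build_parser only mimics the two add_argument lines rather than the full parser of the training script.

```python
import argparse

# (which_freq, freq_loc) pairs copied from the list above, keyed by Hz.
FREQ_TABLE = {125: (0, 7), 250: (1, 10), 500: (2, 13), 1000: (3, 16), 2000: (4, 19)}


def freq_args(center_hz: int) -> list:
    """Return the command-line flags for one centre frequency."""
    if center_hz not in FREQ_TABLE:
        raise ValueError(f"unsupported frequency: {center_hz} Hz")
    which_freq, freq_loc = FREQ_TABLE[center_hz]
    return ["--which_freq", str(which_freq), "--freq_loc", str(freq_loc)]


def build_parser() -> argparse.ArgumentParser:
    """Stand-in parser with the same two options and defaults as above."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--which_freq', type=int, default=2)
    parser.add_argument('--freq_loc', type=int, default=13)
    return parser


args = build_parser().parse_args(freq_args(1000))
print(args.which_freq, args.freq_loc)  # prints: 3 16
```

Keeping the pairs in one table avoids changing one value and forgetting the other.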

3. Multi-GPU parallel training, single process with multiple threads (not recommended)

python train_resnet50_500Hz_multi_card.py

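4. Multi-GPU parallel training, multiple processes with multiple threads (recommended)

python -m torch.distributed.launch --nnodes=1 --nproc_per_node=3 train_resnet50_1000Hz_multi_card_DDP.py

Here nnodes is the number of machines and does not need to be changed; nproc_per_node is the number of GPUs. This mode is faster and is the recommended one. Before running, remember to change the dataset paths and select the frequency, following the same steps as above.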
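For context, torch.distributed.launch starts one process per GPU and tells each process which GPU it owns, either through a --local_rank argument (spelled --local-rank in newer PyTorch releases) or through the LOCAL_RANK environment variable. The sketch below shows one way a script could read that value; parse_local_rank is a hypothetical helper and is not taken from the DDP script in this repository.

```python
import argparse
import os
from typing import Mapping, Optional, Sequence


def parse_local_rank(argv: Optional[Sequence[str]] = None,
                     env: Optional[Mapping[str, str]] = None) -> int:
    """Return this process's local rank (GPU index on the current machine).

    Accepts both spellings of the launcher flag and falls back to the
    LOCAL_RANK environment variable, then to 0 for a plain single run.
    """
    env = os.environ if env is None else env
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", "--local-rank", dest="local_rank",
                        type=int, default=int(env.get("LOCAL_RANK", 0)))
    # parse_known_args ignores the script's other options
    args, _unknown = parser.parse_known_args(argv)
    return args.local_rank
```

In a training script the returned value would typically be used to select the device for that process before the model is wrapped for distributed training.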

--------------Training: additional notes--------------

train_resnet50_1000Hz_multi_card_DDP_fixed.py fixes several problems that occurred when running train_resnet50_1000Hz_multi_card_DDP.py. Use train_resnet50_1000Hz_multi_card_DDP_fixed.py when running the program.
new_data_load_original_fixed.py fixes several problems that occurred when running new_data_load_original.py. Use new_data_load_original_fixed.py when running the program.

-------------Testing------------

Added 230630: auto_test.py

It supports batch testing and is fairly simple; see the code for details.

T60_train_test/test_resnet50_500Hz_old

Code Description

Main file for testing

test_resnet50_500Hz.py

This script tests the older data and can test different frequencies. To run your own test, follow the Run Instructions below.

Run Instructions

To test newly generated data, revise the following places:
1. Change the defaults of the following two arguments to the path of your pretrained model.

parser.add_argument('--save_dir', type=str, default="/data2/pyq/yousonic/old_500hz/1001_autoLoss/save_model/")  
parser.add_argument('--model_path', type=str, default="/data2/pyq/yousonic/old_500hz/1001_autoLoss/save_model/")  

2. The default of the following argument is the number of epochs.

parser.add_argument('--epoch_for_save', type=int, default=60)  

3. The defaults of the following two arguments select the frequency. The (which_freq, freq_loc) pairs are 125Hz:(0, 7); 250Hz:(1, 10); 500Hz:(2, 13); 1kHz:(3, 16); 2kHz:(4, 19).

parser.add_argument('--which_freq', type=int, default=2)   # remember to change this to the matching image index
parser.add_argument('--freq_loc', type=int, default=13)    # remember to change this to the matching octave band

4. Change the default of the following argument to your own output path.

parser.add_argument('--outputresult_dir', type=str, default="./0329_oldmodel_yqhdata_500hz/epoch60")  

5. Change val_dict_root and failed_file to your own data paths; you can also adjust val_batch_size. Note that your data must include the clean data!

if DEBUG == False:
    val_dict_root = "/Users/bajianxiang/Desktop/internship/filter_down_spec_dataset"
    val_batch_size = 1
    failed_file = "/Users/bajianxiang/Desktop/internship/filter_down_spec_dataset/koli-national-park-winter/koli-national-park-winter_koli_snow_site4_1way_bformat_1_koli-national-park-winter_东北话男声_1_TIMIT_a005_100_110_10dB-0.pt"
else:
    val_dict_root = "/data/xbj/0927_500hz_with_clean/val"
    failed_file = "/data/xbj/0927_500hz_with_clean/val/hoffmann-lime-kiln-langcliffeuk/hoffmann-lime-kiln-langcliffeuk_ir_p2_0_1_hoffmann-lime-kiln-langcliffeuk_东北话男声_1_TIMIT_a053_110_120_10dB-0.pt"
    val_batch_size = 6  

6. If the model was trained on multiple GPUs, add the following code.

   net = nn.DataParallel(net)
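The wrapper is needed because nn.DataParallel stores the original network under an attribute called module, so weights saved from a wrapped model have keys that start with "module.". Wrapping net before loading makes the keys match. An alternative, sketched below, is to strip that prefix and load the weights into the unwrapped network. The helper strip_module_prefix is hypothetical, and the sketch assumes the checkpoint is a plain state dict saved from the wrapped model.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_module_prefix(state_dict: Mapping[str, Any]) -> "OrderedDict[str, Any]":
    """Remove the leading 'module.' that nn.DataParallel adds to each key.

    Keys without the prefix are passed through unchanged, so the helper
    is safe to call on single-GPU checkpoints as well.
    """
    prefix = "module."
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
```

The result could then be passed to net.load_state_dict(...) on a network that has not been wrapped in nn.DataParallel.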

Files that do not need to be changed!!!

model
utils
data_load.py
data_load_aug2.py
data_load_for_FPN_spec.py
eval_same_room.py
gen_boxplot.py
get_deverb_spec_test.py
gtg.py
load_same_room.py
new_data_load_original.py
new_data_load_original_for_our_data.py
new_data_load_origin_old.py
new_thread_test.py
splweighting.py
SSIMLoss.py
Test_new_dataset.py
test_our_data.py
thread_test.py
train_resnet50_500Hz.py
valdata_meant60.py
workspace_utils.py

Contributors

max-vegetable, pyq2377
