Comments (3)
This is the whole code:
# -*- encoding:utf-8-*-
import os
import glob
import pickle
import datetime
import numpy as np
from keras.layers import (Conv3D, AveragePooling3D, MaxPooling3D, Activation, UpSampling3D, merge, Input, Reshape,
                          Permute)
from keras import backend as K
from keras.models import Model, load_model
from keras.optimizers import Adam
import SimpleITK as sitk
# The BRATS dataset also contains T2 scans
pool_size = (2, 2, 2)
image_shape = (144, 240, 240)
n_channels = 2
# n_channels is the number of modalities (T1c, FLAIR (fluid-attenuated inversion recovery), etc.)
input_shape = tuple([n_channels] + list(image_shape))
n_labels = 5
batch_size = 1
n_test_subjects = 40
z_crop = 155 - image_shape[0]
training_iterations = 5
def pickle_dump(item, out_file):
    with open(out_file, "wb") as opened_file:
        pickle.dump(item, opened_file)
def pickle_load(in_file):
    with open(in_file, "rb") as opened_file:
        return pickle.load(opened_file)
K.set_image_dim_ordering('th')
smooth = 1.
def dice_coef(y_true, y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
def dice_coef_loss(y_true, y_pred):
    return -dice_coef(y_true, y_pred)
def unet_model():
    inputs = Input(input_shape)
    conv1 = Conv3D(32, 3, 3, 3, activation='relu', border_mode='same')(inputs)
    conv1 = Conv3D(32, 3, 3, 3, activation='relu', border_mode='same')(conv1)
    pool1 = MaxPooling3D(pool_size=pool_size)(conv1)
    conv2 = Conv3D(64, 3, 3, 3, activation='relu', border_mode='same')(pool1)
    conv2 = Conv3D(64, 3, 3, 3, activation='relu', border_mode='same')(conv2)
    pool2 = MaxPooling3D(pool_size=pool_size)(conv2)
    conv3 = Conv3D(128, 3, 3, 3, activation='relu', border_mode='same')(pool2)
    conv3 = Conv3D(128, 3, 3, 3, activation='relu', border_mode='same')(conv3)
    pool3 = MaxPooling3D(pool_size=pool_size)(conv3)
    conv4 = Conv3D(256, 3, 3, 3, activation='relu', border_mode='same')(pool3)
    conv4 = Conv3D(256, 3, 3, 3, activation='relu', border_mode='same')(conv4)
    pool4 = MaxPooling3D(pool_size=pool_size)(conv4)
    conv5 = Conv3D(512, 3, 3, 3, activation='relu', border_mode='same')(pool4)
    conv5 = Conv3D(512, 3, 3, 3, activation='relu', border_mode='same')(conv5)
    up6 = merge([UpSampling3D(size=pool_size)(conv5), conv4], mode='concat', concat_axis=1)
    conv6 = Conv3D(256, 3, 3, 3, activation='relu', border_mode='same')(up6)
    conv6 = Conv3D(256, 3, 3, 3, activation='relu', border_mode='same')(conv6)
    up7 = merge([UpSampling3D(size=pool_size)(conv6), conv3], mode='concat', concat_axis=1)
    conv7 = Conv3D(128, 3, 3, 3, activation='relu', border_mode='same')(up7)
    conv7 = Conv3D(128, 3, 3, 3, activation='relu', border_mode='same')(conv7)
    up8 = merge([UpSampling3D(size=pool_size)(conv7), conv2], mode='concat', concat_axis=1)
    conv8 = Conv3D(64, 3, 3, 3, activation='relu', border_mode='same')(up8)
    conv8 = Conv3D(64, 3, 3, 3, activation='relu', border_mode='same')(conv8)
    up9 = merge([UpSampling3D(size=pool_size)(conv8), conv1], mode='concat', concat_axis=1)
    conv9 = Conv3D(32, 3, 3, 3, activation='relu', border_mode='same')(up9)
    conv9 = Conv3D(32, 3, 3, 3, activation='relu', border_mode='same')(conv9)
    conv10 = Conv3D(n_labels, 1, 1, 1)(conv9)
    act = Activation('sigmoid')(conv10)
    model = Model(input=inputs, output=act)
    model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])
    return model
def train_batch(batch, model):
    x_train = batch[:, :2]
    y_train = get_truth(batch)
    del(batch)
    print(model.train_on_batch(x_train, y_train))
    del(x_train)
    del(y_train)
def read_subject_folder(folder):
    flair_image = sitk.ReadImage(os.path.join(folder, "Flair_subtrMeanDivStd.nii.gz"))
    # t1_image = sitk.ReadImage(os.path.join(folder, "T1.nii.gz"))
    t1c_image = sitk.ReadImage(os.path.join(folder, "T1c_subtrMeanDivStd.nii.gz"))
    truth_image = sitk.ReadImage(os.path.join(folder, "OTMultiClass.nii.gz"))
    # background_image = sitk.ReadImage(os.path.join(folder, "background.nii.gz"))
    return np.array([  # sitk.GetArrayFromImage(t1_image),
                     sitk.GetArrayFromImage(t1c_image),
                     sitk.GetArrayFromImage(flair_image),
                     sitk.GetArrayFromImage(truth_image)])
    # sitk.GetArrayFromImage(background_image)
# def crop_data(data, background_channel=4):
#     if np.all(data[background_channel, :z_crop] == 1):
#         return data[:, z_crop:]
#     elif np.all(data[background_channel, data.shape[1] - z_crop:] == 1):
#         return data[:, :data.shape[1] - z_crop]
#     else:
#         upper = z_crop/2
#         lower = z_crop - upper
#         return data[:, lower:data.shape[1] - upper]
def crop_data(data, z_crop):
    return data[:, z_crop:]
def get_truth(batch, truth_channel=2):
    truth = np.array(batch)[:, truth_channel]
    batch_list = []
    for sample_number in range(truth.shape[0]):
        sample_list = []
        for label in range(n_labels):
            array = np.zeros(truth[sample_number].shape)
            array[truth[sample_number] == label] = 1
            sample_list.append(array)
        batch_list.append(sample_list)
    return np.array(batch_list)
def get_subject_id(subject_dir):
    return subject_dir.split("_")[-2]
def main(overwrite=False):
    model_file = os.path.abspath("3d_unet_model.h5")  # os.path.abspath returns the normalized absolute path
    if not overwrite and os.path.exists(model_file):
        model = load_model(model_file, custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef})
    else:
        model = unet_model()
    train_model(model, model_file, overwrite=overwrite, iterations=training_iterations)
def get_subject_dirs():
    return glob.glob("/home/kaido/workspace/3DUnet/test/*")
def train_model(model, model_file, overwrite=False, iterations=1):
    for i in range(iterations):
        processed_list_file = os.path.abspath("processed_subjects.pkl")
        if overwrite or not os.path.exists(processed_list_file) or i > 0:
            processed_list = []
        else:
            processed_list = pickle_load(processed_list_file)
        subject_dirs = get_subject_dirs()
        testing_ids_file = os.path.abspath("testing_ids.pkl")
        if os.path.exists(testing_ids_file) and not overwrite:
            testing_ids = pickle_load(testing_ids_file)
            if len(testing_ids) > n_test_subjects:
                testing_ids = testing_ids[:n_test_subjects]
                pickle_dump(testing_ids, testing_ids_file)
        else:
            # remove duplicate sessions
            subjects = dict()
            for dirname in subject_dirs:
                subjects[dirname.split('_')[-2]] = dirname
            subject_ids = list(subjects.keys())
            np.random.shuffle(subject_ids)
            testing_ids = subject_ids[:n_test_subjects]
            pickle_dump(testing_ids, testing_ids_file)
        batch = []
        for subject_dir in subject_dirs:
            subject_id = get_subject_id(subject_dir)
            if subject_id in testing_ids or subject_id in processed_list:
                continue
            processed_list.append("Flair_subtrMeanDivStd.nii.gz")
            processed_list.append("T1c_subtrMeanDivStd.nii.gz")
            processed_list.append("OTMultiClass.nii.gz")
            batch.append(crop_data(read_subject_folder('/home/kaido/workspace/3DUnet/train'), z_crop))
            if len(batch) >= batch_size:
                train_batch(np.array(batch), model)
                del(batch)
                batch = []
                print("Saving: " + model_file)
                pickle_dump(processed_list, processed_list_file)
                model.save(model_file)
        if batch:
            train_batch(np.array(batch), model)
            del(batch)
            print("Saving: " + model_file)
            pickle_dump(processed_list, processed_list_file)
            model.save(model_file)
if __name__ == "__main__":
    main(overwrite=False)
Hi @Kaido0, the error is due to a difference in file names. I named my files a certain way so that I could get the subject_id from the file name. Since your files are not named the same way, you are getting the error.
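For example (the folder names below are made up and only illustrate the naming assumption), get_subject_id in the pasted code simply takes the second-to-last underscore-separated token of the directory path, so it only works if your folders follow that convention:

# Hypothetical folder names, used only to illustrate the naming assumption.
print(get_subject_id("/data/BRATS_pat0001_session1"))  # -> "pat0001"
# A folder without that layout, e.g. "/data/subject01", raises an IndexError,
# and a different underscore layout silently returns the wrong token.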
Have you looked into making a custom data generator for your data? That's what I would suggest using to train the model. I'm working on a data generator and adding it to this repository, but I'm not done with it yet. You can take a look at that code, and maybe it will lead you in the right direction.
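As a rough sketch of what such a generator could look like (this reuses the helper functions from the code above -- read_subject_folder, crop_data, get_truth -- and is only an illustration, not the generator being added to the repository), something along these lines yields (inputs, targets) batches that Keras can consume:

# Sketch of a custom data generator (an assumption, not the repository's implementation).
# It loops over subject folders indefinitely and yields one (x, y) batch at a time.
def data_generator(subject_dirs, batch_size=1):
    while True:
        batch = []
        for subject_dir in subject_dirs:
            data = crop_data(read_subject_folder(subject_dir), z_crop)
            batch.append(data)
            if len(batch) >= batch_size:
                batch = np.array(batch)
                x = batch[:, :n_channels]  # T1c and FLAIR channels
                y = get_truth(batch)       # one binary map per label
                yield x, y
                batch = []

# Example usage with the Keras 1.x API used above (epoch sizes are placeholders):
# model.fit_generator(data_generator(get_subject_dirs(), batch_size),
#                     samples_per_epoch=..., nb_epoch=...)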
@Kaido0 check out the latest code and #5, as I have updated the code a lot so that it is better able to handle input data.