Commit 07ca6887 by 前钰

Upload New File

parent b9f607e1
{
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from model import *\n",
"from data import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train your Unet with membrane data\n",
"membrane data is in folder membrane/, it is a binary classification task.\n",
"\n",
"The input shape of image and mask are the same :(batch_size,rows,cols,channel = 1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train with data generator"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\WilliamGinWolf\\AppData\\Roaming\\Python\\Python310\\site-packages\\keras\\optimizers\\optimizer_v2\\adam.py:110: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n",
" super(Adam, self).__init__(name, **kwargs)\n",
"C:\\Users\\WilliamGinWolf\\AppData\\Local\\Temp\\ipykernel_26284\\3820683722.py:11: UserWarning: `Model.fit_generator` is deprecated and will be removed in a future version. Please use `Model.fit`, which supports generators.\n",
" model.fit_generator(myGene,steps_per_epoch=2000,epochs=5,callbacks=[model_checkpoint])\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 30 images belonging to 1 classes.\n",
"Found 30 images belonging to 1 classes.\n",
"Epoch 1/5\n",
"2000/2000 [==============================] - ETA: 0s - loss: 0.1944 - accuracy: 0.9145\n",
"Epoch 1: loss improved from inf to 0.19437, saving model to unet_membrane.hdf5\n",
"2000/2000 [==============================] - 258s 124ms/step - loss: 0.1944 - accuracy: 0.9145\n",
"Epoch 2/5\n",
"2000/2000 [==============================] - ETA: 0s - loss: 0.1241 - accuracy: 0.9455\n",
"Epoch 2: loss improved from 0.19437 to 0.12412, saving model to unet_membrane.hdf5\n",
"2000/2000 [==============================] - 242s 121ms/step - loss: 0.1241 - accuracy: 0.9455\n",
"Epoch 3/5\n",
"2000/2000 [==============================] - ETA: 0s - loss: 0.0969 - accuracy: 0.9577\n",
"Epoch 3: loss improved from 0.12412 to 0.09691, saving model to unet_membrane.hdf5\n",
"2000/2000 [==============================] - 242s 121ms/step - loss: 0.0969 - accuracy: 0.9577\n",
"Epoch 4/5\n",
"2000/2000 [==============================] - ETA: 0s - loss: 0.0823 - accuracy: 0.9641\n",
"Epoch 4: loss improved from 0.09691 to 0.08232, saving model to unet_membrane.hdf5\n",
"2000/2000 [==============================] - 242s 121ms/step - loss: 0.0823 - accuracy: 0.9641\n",
"Epoch 5/5\n",
"2000/2000 [==============================] - ETA: 0s - loss: 0.0737 - accuracy: 0.9679\n",
"Epoch 5: loss improved from 0.08232 to 0.07368, saving model to unet_membrane.hdf5\n",
"2000/2000 [==============================] - 241s 121ms/step - loss: 0.0737 - accuracy: 0.9679\n"
]
},
{
"data": {
"text/plain": "<keras.callbacks.History at 0x16316df1fc0>"
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_gen_args = dict(rotation_range=0.2,\n",
" width_shift_range=0.05,\n",
" height_shift_range=0.05,\n",
" shear_range=0.05,\n",
" zoom_range=0.05,\n",
" horizontal_flip=True,\n",
" fill_mode='nearest')\n",
"myGene = trainGenerator(2,'data/membrane/train','image','label',data_gen_args,save_to_dir = None)\n",
"model = unet()\n",
"model_checkpoint = ModelCheckpoint('unet_membrane.hdf5', monitor='loss',verbose=1, save_best_only=True)\n",
"model.fit_generator(myGene,steps_per_epoch=2000,epochs=5,callbacks=[model_checkpoint])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train with npy file"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"#imgs_train,imgs_mask_train = geneTrainNpy(\"data/membrane/train/aug/\",\"data/membrane/train/aug/\")\n",
"#model.fit(imgs_train, imgs_mask_train, batch_size=2, nb_epoch=10, verbose=1,validation_split=0.2, shuffle=True, callbacks=[model_checkpoint])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### test your model and save predicted results"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\WilliamGinWolf\\AppData\\Local\\Temp\\ipykernel_26284\\1405464076.py:4: UserWarning: `Model.predict_generator` is deprecated and will be removed in a future version. Please use `Model.predict`, which supports generators.\n",
" results = model.predict_generator(testGene,30,verbose=1)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"30/30 [==============================] - 2s 20ms/step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\0_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\1_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\2_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\3_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\4_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\5_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\6_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\7_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\8_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\9_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\10_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\11_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\12_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\13_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\14_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\15_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\16_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\17_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\18_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\19_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\20_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\21_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\22_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\23_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\24_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\25_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\26_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\27_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\28_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n",
"F:\\Academic\\codes\\DT\\PROJECTS\\UNets\\data.py:125: UserWarning: data/membrane/test\\29_predict.png is a low contrast image\n",
" io.imsave(os.path.join(save_path,\"%d_predict.png\"%i),img)\n"
]
}
],
"source": [
"testGene = testGenerator(\"data/membrane/test\")\n",
"model = unet()\n",
"model.load_weights(\"unet_membrane.hdf5\")\n",
"results = model.predict_generator(testGene,30,verbose=1)\n",
"saveResult(\"data/membrane/test\",results)"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment