Skip to content
Snippets Groups Projects
Commit 6f35228f authored by Liang Cheng's avatar Liang Cheng
Browse files

update

parent f7840a39
No related branches found
No related tags found
No related merge requests found
......@@ -109,6 +109,8 @@ end
| train_models | device,models,modelsindex,train_dataset,lr,optimizer,epochs |NULL |用于训练私有数据集在MNIST数据集上|
由于这里batch size设置小了,所以有点毛病
![pretrained_public_mnist_initial_result](Src/Figure/private_model_public_dataset_initial_train_losses.png)
### pretrained_public_mnist_continue.py
用于继续训练各个locla的私有模型在MNIST数据集上,直到模型收敛
......
Src/Figure/collaborative_private_model_private_dataset_train_losses.png

40.4 KiB | W: | H:

Src/Figure/collaborative_private_model_private_dataset_train_losses.png

51.4 KiB | W: | H:

Src/Figure/collaborative_private_model_private_dataset_train_losses.png
Src/Figure/collaborative_private_model_private_dataset_train_losses.png
Src/Figure/collaborative_private_model_private_dataset_train_losses.png
Src/Figure/collaborative_private_model_private_dataset_train_losses.png
  • 2-up
  • Swipe
  • Onion skin
Src/Figure/collaborative_train_losses.png

14.2 KiB | W: | H:

Src/Figure/collaborative_train_losses.png

46.8 KiB | W: | H:

Src/Figure/collaborative_train_losses.png
Src/Figure/collaborative_train_losses.png
Src/Figure/collaborative_train_losses.png
Src/Figure/collaborative_train_losses.png
  • 2-up
  • Swipe
  • Onion skin
......@@ -5,21 +5,9 @@ import torch
from torch.utils.data import DataLoader,Dataset
from torch import nn
import matplotlib.pyplot as plt
from utils import get_model_list
def get_model_list(url,modelsindex,models):
model_list = []
model_type_list = []
filePath = url
for root, dirs, files in os.walk(filePath, topdown=False):
for name in files:
model_type_list.append(int(name[name.find('Type')+4]))
net = models[modelsindex[int(name[name.find('Type')+4])]]()
net.load_state_dict(torch.load(os.path.join(root, name)))
model_list.append(net)
return model_list,model_type_list
class DatasetSplit(Dataset):
"""
An abstract Dataset class wrapped around Pytorch Dataset class
......
......@@ -28,13 +28,13 @@ def args_parser():
# Collaborative_private_model_femnist_balanced
parser.add_argument('--new_collaborative_training',type=bool,default=False,help='whether train model from initial condition')
parser.add_argument('--Collaborativeurl',type=str,default='Src\CollaborativeModel',help='collaborative model location')
parser.add_argument('--collaborative_epoch',type=int,default=1,help='collaborative_epoch for train on public mnist')
parser.add_argument('--collaborative_epoch',type=int,default=4,help='collaborative_epoch for train on public mnist')
# Collaborative_step
parser.add_argument('--Communicationepoch',type=int,default=10,help='Collaobrative epoch in Step3')
# collaborative_private_model_femnist_balanced
parser.add_argument('--Communication_private_epoch',type=int,default=2 ,help='Local private training during colaboratiive time')
parser.add_argument('--Communication_private_epoch',type=int,default=8 ,help='Local private training during colaboratiive time')
args = parser.parse_args()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment