# collaborative_private_model_femnist_balanced.py
from  data_utils import  get_public_dataset, get_private_dataset_balanced,FEMNIST_iid
from  models import CNN_2layer_fc_model,CNN_3layer_fc_model
import os
import torch
from torch.utils.data import DataLoader,Dataset
from torch import nn
import matplotlib.pyplot as plt
from utils import get_model_list

class DatasetSplit(Dataset):
    """Dataset view restricted to the positions listed in ``idxs``.

    Each item of the wrapped ``dataset`` is unpacked as an ``(image, label)``
    pair, and both parts are returned as tensors.

    Args:
        dataset: indexable object whose items are ``(image, label)`` pairs.
        idxs: iterable of positions into ``dataset``; each is coerced to ``int``.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Coerce to plain Python ints (the partitioner may hand back
        # NumPy/float index values).
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        # torch.as_tensor reuses inputs that are already tensors instead of
        # copying them. torch.tensor(<Tensor>) makes a redundant copy and emits
        # a UserWarning. The returned tensors may share memory with the
        # underlying dataset.
        return torch.as_tensor(image), torch.as_tensor(label)

def _build_optimizer(model, args):
    """Return the optimizer named by ``args.optimizer`` for ``model``'s parameters.

    Args:
        model: module whose parameters will be optimised.
        args: options namespace; reads ``optimizer`` and ``lr``.

    Raises:
        ValueError: if ``args.optimizer`` is neither ``'sgd'`` nor ``'adam'``.
    """
    if args.optimizer == 'sgd':
        return torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.5)
    if args.optimizer == 'adam':
        return torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
    # Fail fast with a clear message rather than an UnboundLocalError later.
    raise ValueError(
        "Unsupported optimizer {!r}; expected 'sgd' or 'adam'".format(args.optimizer))


def _train_local_model(n, model, model_type, trainloader, args, device):
    """Train ``model`` on ``trainloader`` and return its per-epoch mean losses.

    Args:
        n: index of the model, used only in progress messages.
        model: the local model; moved to ``device`` and put in train mode.
        model_type: architecture name, used only in progress messages.
        trainloader: DataLoader over this model's private shard.
        args: options namespace; reads ``Communication_private_epoch``,
            ``optimizer`` and ``lr``.
        device: ``'cuda'`` or ``'cpu'``.

    Returns:
        list[float]: mean batch loss for each epoch (``nan`` for an epoch
        that produced no batches, i.e. an empty shard).

    Raises:
        ValueError: if ``args.optimizer`` is not supported.
    """
    model.to(device)
    model.train()
    # Per the torch.optim docs, move the model to its device before
    # constructing its optimizer.
    optimizer = _build_optimizer(model, args)
    # NLLLoss expects log-probabilities, so the models presumably end in
    # log_softmax -- TODO confirm in models.py.
    criterion = nn.NLLLoss().to(device)

    train_epoch_losses = []
    for epoch in range(args.Communication_private_epoch):
        train_batch_losses = []
        for batch_idx, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Read the scalar once (.item() forces a device sync).
            batch_loss = loss.item()
            if batch_idx % 5 == 0:
                print('Local Model {} Type {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    n, model_type, epoch + 1, batch_idx * len(images), len(trainloader.dataset),
                    100. * batch_idx / len(trainloader), batch_loss))
            train_batch_losses.append(batch_loss)
        # An empty shard yields no batches; avoid dividing by zero and keep one
        # entry per epoch so the loss curves stay aligned.
        if train_batch_losses:
            train_epoch_losses.append(sum(train_batch_losses) / len(train_batch_losses))
        else:
            train_epoch_losses.append(float('nan'))
    return train_epoch_losses


def _plot_private_losses(losses, model_type_list):
    """Print and plot each model's per-epoch training loss, save and show the figure.

    The figure is saved under ``Src/Figure/`` (created if missing) and then
    displayed with ``plt.show()``.
    """
    plt.figure()
    for i, val in enumerate(losses):
        print(val)
        # The training log numbers epochs from 1; match that on the x-axis.
        plt.plot(range(1, len(val) + 1), val,
                 label='Local Model {} ({})'.format(i, model_type_list[i]))
    if losses:
        plt.legend()
    plt.title('collaborative_private_model_private_dataset_train_losses')
    plt.xlabel('epoches')
    plt.ylabel('Train loss')
    os.makedirs('Src/Figure', exist_ok=True)
    plt.savefig('Src/Figure/collaborative_private_model_private_dataset_train_losses.png')
    plt.show()


def collaborative_private_model_femnist_train(args):
    """Train every collaborative model on its own private FEMNIST shard.

    Loads the balanced private FEMNIST split and partitions the training set
    across ``args.user_number`` users with ``FEMNIST_iid``. Model ``n`` is
    trained on user ``n``'s shard. Each trained ``state_dict`` is saved to
    ``Src/CollaborativeModel/LocalModel{n}Type{type}.pkl``, and the per-epoch
    loss curves are plotted and saved under ``Src/Figure/``.

    Args:
        args: options namespace; reads ``gpu``, ``user_number``,
            ``Collaborativeurl``, ``optimizer``, ``lr`` and
            ``Communication_private_epoch`` (plus whatever the data/model
            helpers read).

    Raises:
        ValueError: if ``args.optimizer`` is not ``'sgd'`` or ``'adam'``.
    """
    device = 'cuda' if args.gpu else 'cpu'

    # Load the private FEMNIST data; only the training split is used here.
    train_dataset, _test_dataset = get_private_dataset_balanced(args)
    # Presumably a mapping user index -> sample indices -- confirm in data_utils.
    user_groups = FEMNIST_iid(train_dataset, args.user_number)

    # Architecture name -> model constructor, handed to get_model_list together
    # with args.Collaborativeurl (presumably where saved models are loaded
    # from -- TODO confirm in utils).
    models = {"2_layer_CNN": CNN_2layer_fc_model,
              "3_layer_CNN": CNN_3layer_fc_model}
    modelsindex = ["2_layer_CNN", "3_layer_CNN"]
    model_list, model_type_list = get_model_list(args.Collaborativeurl, modelsindex, models)

    print('Begin Private Training')
    # Create the checkpoint directory up front so torch.save cannot fail
    # after training has already finished.
    os.makedirs('Src/CollaborativeModel', exist_ok=True)

    private_model_private_dataset_train_losses = []
    for n, model in enumerate(model_list):
        print('train Local Model {} on Private Dataset'.format(n))
        # Model n trains on user n's shard (assumes len(model_list) does not
        # exceed the number of user groups).
        trainloader = DataLoader(DatasetSplit(train_dataset, list(user_groups[n])),
                                 batch_size=5, shuffle=True)
        train_epoch_losses = _train_local_model(
            n, model, model_type_list[n], trainloader, args, device)
        torch.save(model.state_dict(),
                   'Src/CollaborativeModel/LocalModel{}Type{}.pkl'.format(n, model_type_list[n]))
        private_model_private_dataset_train_losses.append(train_epoch_losses)

    _plot_private_losses(private_model_private_dataset_train_losses, model_type_list)
    print('End Private Training')

from option import args_parser

if __name__ == '__main__':
    # Script entry point: parse command-line options and run private training.
    cli_args = args_parser()
    collaborative_private_model_femnist_train(cli_args)