Skip to content

Commit 46ee3a7

Browse files
authored
Merge pull request #585 from wizard1203/dev/v0.7.0
SP optimizers
2 parents b0e7ac2 + e38c98a commit 46ee3a7

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

67 files changed

+5272
-35
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
```
2+
sh test_run_server.sh
3+
sh test_run_client.sh 1
4+
sh test_run_client.sh 2
5+
```
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
common_args:
2+
training_type: "cross_silo"
3+
scenario: "horizontal"
4+
using_mlops: false
5+
random_seed: 0
6+
7+
environment_args:
8+
bootstrap: config/bootstrap.sh
9+
10+
data_args:
11+
dataset: "mnist"
12+
data_cache_dir: ~/fedml_data
13+
partition_method: "hetero"
14+
partition_alpha: 0.5
15+
16+
model_args:
17+
model: "lr"
18+
model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically
19+
global_model_file_path: "./model_file_cache/global_model.pt"
20+
21+
train_args:
22+
federated_optimizer: "FedAvg"
23+
client_id_list:
24+
client_num_in_total: 1000
25+
client_num_per_round: 2
26+
comm_round: 5
27+
epochs: 1
28+
batch_size: 10
29+
client_optimizer: sgd
30+
learning_rate: 0.03
31+
weight_decay: 0.001
32+
33+
validation_args:
34+
frequency_of_the_test: 5
35+
36+
device_args:
37+
worker_num: 2
38+
using_gpu: false
39+
gpu_mapping_file: config/gpu_mapping.yaml
40+
gpu_mapping_key: mapping_default
41+
42+
comm_args:
43+
backend: "MQTT_S3"
44+
mqtt_config_path: config/mqtt_config.yaml
45+
s3_config_path: config/s3_config.yaml
46+
47+
tracking_args:
48+
log_file_dir: ./log
49+
enable_wandb: false
50+
wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408
51+
wandb_project: fedml
52+
wandb_name: fedml_torch_fedavg_mnist_lr
53+
54+
#lsa_args:
55+
# prime_number: 2 ** 15 - 19
56+
# precision_parameter: 10
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Message-type identifiers. CONNECTION_IS_READY is the only integer-valued
# type; every other type is a string identical to its constant name.
MSG_TYPE_CONNECTION_IS_READY = 0
2+
MSG_TYPE_NEIGHBOR_CHECK_NODE_STATUS = "MSG_TYPE_NEIGHBOR_CHECK_NODE_STATUS"
3+
MSG_TYPE_NEIGHBOR_REPORT_NODE_STATUS = "MSG_TYPE_NEIGHBOR_REPORT_NODE_STATUS"
4+
MSG_TYPE_FLOW_FINISH = "MSG_TYPE_FLOW_FINISH"
5+
6+
# Presumably the message field under which the type above is stored --
# TODO confirm against the code that builds/parses messages.
MSG_ARG_KEY_TYPE = "msg_type"
7+
8+
# Parameter keys naming the sending and receiving node ids.
PARAMS_KEY_SENDER_ID = "sender_id"
9+
PARAMS_KEY_RECEIVER_ID = "receiver_id"
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import logging
2+
3+
import fedml
4+
from fedml import FedMLRunner
5+
from fedml.core import FedMLExecutor, Params, FedMLAlgorithmFlow
6+
7+
8+
class Client(FedMLExecutor):
    """Client-side executor used by the FedML algorithm-flow test.

    The executor id is taken from ``args.rank`` and its only neighbor is
    node ``0``.
    """

    def __init__(self, args):
        """Register this executor and create empty runtime slots.

        Args:
            args: Parsed FedML arguments; ``args.rank`` is used as the id.
        """
        self.args = args
        super().__init__(args.rank, [0])

        # Populated later through ``init``.
        self.device = self.dataset = self.model = None

    def init(self, device, dataset, model):
        """Attach the device, dataset and model this client works with."""
        self.device, self.dataset, self.model = device, dataset, model

    def local_training(self):
        """Flow step: log the start of training and pass the params through.

        Returns:
            The params object obtained from ``get_params()``, unchanged.
        """
        logging.info("local_training start")
        received = self.get_params()
        # Placeholder: the weights are looked up but no training happens here.
        _weights = received.get(Params.KEY_MODEL_PARAMS)
        return received

    def handle_init_global_model(self):
        """Flow step: repackage the received model parameters.

        Returns:
            A new ``Params`` holding only ``Params.KEY_MODEL_PARAMS``.
        """
        incoming = self.get_params()
        weights = incoming.get(Params.KEY_MODEL_PARAMS)

        outgoing = Params()
        outgoing.add(Params.KEY_MODEL_PARAMS, weights)
        return outgoing
37+
38+
39+
class Server(FedMLExecutor):
    """Server-side executor used by the FedML algorithm-flow test.

    The executor id is taken from ``args.rank`` and its neighbors are the
    client nodes ``1`` and ``2``.
    """

    def __init__(self, args):
        """Register this executor and reset the round bookkeeping.

        Args:
            args: Parsed FedML arguments; ``args.rank`` is used as the id.
        """
        self.args = args
        node_id = args.rank  # avoid shadowing the builtin ``id``
        neighbor_id_list = [1, 2]
        super().__init__(node_id, neighbor_id_list)

        # Populated later through ``init``.
        self.device = None
        self.dataset = None
        self.model = None

        # Number of completed aggregation rounds.
        self.round_idx = 0

        # Clients heard from in the current round / clients expected per round.
        self.client_count = 0
        # Derived from the neighbor list so the two cannot drift apart.
        self.client_num = len(neighbor_id_list)

    def init(self, device, dataset, model):
        """Attach the device, dataset and model this server works with."""
        self.device = device
        self.dataset = dataset
        self.model = model

    def init_global_model(self):
        """Flow step: publish the initial global model.

        Returns:
            A ``Params`` holding ``self.model.state_dict()`` under
            ``Params.KEY_MODEL_PARAMS``.
        """
        logging.info("init_global_model")
        params = Params()
        params.add(Params.KEY_MODEL_PARAMS, self.model.state_dict())
        return params

    def server_aggregate(self):
        """Flow step: collect one client update; emit params once all arrived.

        Called once per client message. Until ``client_num`` messages have
        been counted it returns ``None`` (presumably this keeps the flow from
        advancing -- confirm against ``FedMLAlgorithmFlow``). On the last
        message of a round the counter is reset, ``round_idx`` is advanced,
        and a fresh ``Params`` is returned.

        NOTE: no averaging is performed; the parameters forwarded are those
        of the message that completed the round.

        Returns:
            ``Params`` with ``Params.KEY_MODEL_PARAMS`` when the round is
            complete, otherwise ``None``.
        """
        logging.info("server_aggregate")
        received = self.get_params()
        model_params = received.get(Params.KEY_MODEL_PARAMS)

        self.client_count += 1
        if self.client_count != self.client_num:
            # Still waiting for the remaining clients of this round.
            return None

        self.client_count = 0
        # Count rounds, not messages: advance only when the round is complete
        # (previously this was bumped on every client message).
        self.round_idx += 1

        result = Params()
        result.add(Params.KEY_MODEL_PARAMS, model_params)
        return result

    def final_eval(self):
        """Flow step: final evaluation placeholder; only logs."""
        logging.info("final_eval")
81+
82+
83+
if __name__ == "__main__":
    args = fedml.init()

    # Resolve the device, then load the data and build the model from it.
    device = fedml.device.get_device(args)
    dataset, output_dim = fedml.data.load(args)
    model = fedml.model.create(args, output_dim)

    # Rank 0 runs the server executor; every other rank runs a client.
    executor = Server(args) if args.rank == 0 else Client(args)
    executor.init(device, dataset, model)

    # Declare the flow: one-off initialisation, then a train/aggregate pair
    # per communication round, then the final evaluation.
    flow = FedMLAlgorithmFlow(args, executor)
    flow.add_flow("init_global_model", Server.init_global_model)
    flow.add_flow("handle_init", Client.handle_init_global_model)
    for _ in range(args.comm_round):
        flow.add_flow("local_training", Client.local_training)
        flow.add_flow("server_aggregate", Server.server_aggregate)
    flow.add_flow("final_eval", Server.final_eval)
    flow.build()

    runner = FedMLRunner(args, device, dataset, model, algorithm_flow=flow)
    runner.run()
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
# Launch one FedML flow-test client.
#
# Usage: sh <this script> RANK [RUN_ID]
#   RANK    rank forwarded to --rank (required)
#   RUN_ID  value forwarded to --run_id (optional, defaults to 0)
#
# RUN_ID is defaulted so that calling the script with only a rank does not
# leave --run_id without a value; use the same RUN_ID for the server.
RANK=${1:?usage: $0 RANK [RUN_ID]}
RUN_ID=${2:-0}
python test_fedml_flow.py --cf fedml_config.yaml --rank "$RANK" --role client --run_id "$RUN_ID"
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Launch the FedML flow-test server (rank 0).
#
# Usage: sh <this script> [RUN_ID]
#   RUN_ID  value forwarded to --run_id (optional, defaults to 0)
#
# RUN_ID is defaulted so that calling the script with no arguments does not
# leave --run_id without a value; use the same RUN_ID for the clients.
RUN_ID=${1:-0}
python test_fedml_flow.py --cf fedml_config.yaml --rank 0 --role server --run_id "$RUN_ID"
Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11

22

3-
mpirun -np 9 \
4-
-host "localhost:9" \
5-
/home/chaoyanghe/anaconda3/envs/fedml/bin/python main.py --cf config/schedule_femnist_2.yaml \
6-
--override_cmd_args
3+
# mpirun -np 9 \
4+
# -host "localhost:9" \
5+
# /home/chaoyanghe/anaconda3/envs/fedml/bin/python main.py --cf config/schedule_femnist_2.yaml \
6+
# --override_cmd_args
77

88

99

@@ -14,3 +14,27 @@ mpirun -np 9 \
1414

1515

1616

17+
18+
# mpirun -np 9 \
19+
# -host "localhost:9" \
20+
# /home/chaoyanghe/anaconda3/envs/fedml/bin/python main.py --cf config/schedule_stackoverflow.yaml \
21+
# --override_cmd_args
22+
23+
24+
# mpirun -np 9 \
25+
# -host "localhost:9" \
26+
# /home/chaoyanghe/anaconda3/envs/fedml/bin/python main.py --cf config/schedule_stackoverflow_2.yaml \
27+
# --override_cmd_args
28+
29+
30+
# mpirun -np 5 \
31+
# -host "localhost:5" \
32+
# /home/chaoyanghe/anaconda3/envs/fedml/bin/python main.py --cf config/schedule_reddit.yaml \
33+
# --override_cmd_args
34+
35+
36+
# mpirun -np 5 \
37+
# -host "localhost:5" \
38+
# /home/chaoyanghe/anaconda3/envs/fedml/bin/python main.py --cf config/schedule_reddit_2.yaml \
39+
# --override_cmd_args
40+

python/examples/simulation/mpi_torch_fedavg/config/schedule_femnist.yaml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ train_args:
1717
federated_optimizer: "FedAvg_seq"
1818
client_id_list: "[]"
1919
client_num_in_total: 3400
20-
client_num_per_round: 1000
21-
comm_round: 500
20+
client_num_per_round: 100
21+
comm_round: 1000
2222
epochs: 10
2323
batch_size: 20
2424
client_optimizer: sgd
@@ -53,8 +53,12 @@ tracking_args:
5353
run_name: fedml_schedule_bench
5454
wandb_only_server: True
5555
using_mlops: False
56-
# simulation_schedule: "LinearFit-DP"
56+
simulation_schedule: "LinearFit-DP"
5757
# runtime_est_mode: "time_window" # EMA
58+
simulation_gpu_hetero: "ratio"
59+
gpu_hetero_ratio: 1.0
60+
# simulation_environment_hetero: "cos"
61+
# environment_hetero_ratio: 1.0
5862

5963
attack_args:
6064
enable_attack: false

python/examples/simulation/mpi_torch_fedavg/config/schedule_femnist_2.yaml

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,12 @@ model_args:
1414

1515

1616
train_args:
17-
# federated_optimizer: "FedAvg_seq"
18-
federated_optimizer: "FedOpt_seq"
17+
federated_optimizer: "FedAvg_seq"
1918
client_id_list: "[]"
2019
client_num_in_total: 3400
21-
client_num_per_round: 10
22-
comm_round: 500
23-
epochs: 1
20+
client_num_per_round: 100
21+
comm_round: 1000
22+
epochs: 10
2423
batch_size: 20
2524
client_optimizer: sgd
2625
learning_rate: 0.05
@@ -54,8 +53,12 @@ tracking_args:
5453
run_name: fedml_schedule_bench
5554
wandb_only_server: True
5655
using_mlops: False
57-
simulation_schedule: "LinearFit-DP"
58-
runtime_est_mode: "time_window" # EMA
56+
# simulation_schedule: "LinearFit-DP"
57+
# runtime_est_mode: "time_window" # EMA
58+
# simulation_gpu_hetero: "ratio"
59+
# gpu_hetero_ratio: 1.0
60+
# simulation_environment_hetero: "cos"
61+
# environment_hetero_ratio: 1.0
5962

6063
attack_args:
6164
enable_attack: false
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
common_args:
2+
training_type: "simulation"
3+
random_seed: 0
4+
5+
data_args:
6+
dataset: "reddit"
7+
data_cache_dir: "/home/chaoyanghe/FedScale/benchmark/dataset/data/reddit"
8+
data_map_file: "/home/chaoyanghe/FedScale//benchmark/dataset/data/reddit/client_data_mapping/train.csv"
9+
partition_method: "hetero"
10+
partition_alpha: 0.5
11+
filter_less: 21
12+
num_loaders: 0
13+
task: "nlp"
14+
block_size: 64
15+
mlm_probability: 0.15
16+
overwrite_cache: False
17+
num_class: 10000
18+
19+
model_args:
20+
# model: "cnn"
21+
# model: "rnn" # resnet18
22+
model: "albert-base-v2"
23+
24+
train_args:
25+
federated_optimizer: "FedAvg_seq"
26+
client_id_list: "[]"
27+
client_num_in_total: 3400
28+
client_num_per_round: 100
29+
comm_round: 405
30+
epochs: 5
31+
batch_size: 20
32+
client_optimizer: sgd
33+
learning_rate: 0.0005
34+
weight_decay: 0.0005
35+
lr_schedule: None
36+
37+
38+
validation_args:
39+
frequency_of_the_test: 400
40+
41+
device_args:
42+
worker_num: 4
43+
using_gpu: false
44+
gpu_mapping_file: config/gpu_mapping.yaml
45+
gpu_mapping_key: mapping_default
46+
# gpu_util_parse: "localhost:2,1,1,1,1,1,1,1"
47+
# gpu_util_parse: "localhost:2,1,1,1,0,0,0,0"
48+
gpu_util_parse: "localhost:2,1,0,1,1,0,0,0"
49+
50+
comm_args:
51+
backend: "MPI"
52+
is_mobile: 0
53+
54+
55+
tracking_args:
56+
log_file_dir: ./log
57+
enable_wandb: True
58+
wandb_entity: automl
59+
wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408
60+
wandb_project: bench_optim
61+
wandb_name: fedml_optim_bench
62+
run_name: fedml_schedule_bench
63+
wandb_only_server: True
64+
using_mlops: False
65+
# simulation_schedule: "LinearFit-DP"
66+
# runtime_est_mode: "time_window" # EMA
67+
# simulation_gpu_hetero: "ratio"
68+
# gpu_hetero_ratio: 1.0
69+
70+
attack_args:
71+
enable_attack: false
72+
attack_type: None
73+
74+
defense_args:
75+
enable_defense: False
76+
defense_type: norm_diff_clipping
77+
norm_bound: 5.0
78+
79+
80+
81+

0 commit comments

Comments
 (0)