Workflow of the vertical (hetero) linear regression algorithm in multi-party privacy-preserving computation: the core stages are sample alignment, model training, and prediction.
1. Sample alignment: sample alignment uses RSA-based private set intersection; its flow and the messages exchanged between the parties are shown in the figure below.
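Before the instrumented source below, the following self-contained sketch illustrates the blind-RSA intersection algebra that the added logs trace (Z_H, Y_G, Z_G, and the final intersection). It is not FATE code: the toy key size, the SHA-256 hashing, and the sample IDs are illustrative assumptions only (Python 3.8+ is assumed for pow(x, -1, n)).
import hashlib
import math
import random

def h_int(x):
    # first hash: map an ID to an integer (FATE hashes to hex and converts with int(., 16))
    return int(hashlib.sha256(str(x).encode()).hexdigest(), 16)

def h_out(x):
    # second hash: applied to the RSA signature before comparison
    return hashlib.sha256(str(x).encode()).hexdigest()

# toy RSA key pair (PK = (e, n), SK = d); real runs use >= 1024-bit keys from RsaEncrypt
p, q = 61, 53
n = p * q
e = 17
d = pow(e, -1, (p - 1) * (q - 1))

host_ids, guest_ids = {"a", "b", "c"}, {"b", "c", "d"}

# Host: Z_H = hash(hash(id)^d mod n), kept as a lookup back to the raw id
z_h = {h_out(pow(h_int(i) % n, d, n)): i for i in host_ids}

# Guest: blind each hashed id with a fresh random r coprime to n, Y_G = r^e * hash(id) mod n
r = {}
for i in guest_ids:
    while True:
        cand = random.SystemRandom().randrange(2, n)
        if math.gcd(cand, n) == 1:
            r[i] = cand
            break
y_g = {i: (pow(r[i], e, n) * (h_int(i) % n)) % n for i in guest_ids}

# Host signs the blinded values: Z_G = Y_G^d mod n = r * hash(id)^d mod n
# (keyed by the raw id here for readability; in the protocol the Host only sees the blinded values)
z_g = {i: pow(v, d, n) for i, v in y_g.items()}

# Guest unblinds (divides out r) and hashes, giving the same value the Host published in Z_H
d_g = {i: h_out((v * pow(r[i], -1, n)) % n) for i, v in z_g.items()}
intersection = sorted(i for i, t in d_g.items() if t in z_h)
print(intersection)   # ['b', 'c']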
Source code:
Host side: log statements added to the source
RsaIntersectionHost->run()
L175:
LOGGER.debug("Host generated RSA key pair: PK_e={}, PK_n={}, SK_d={}".format(self.e, self.n, self.d))
L181:
LOGGER.debug("Host sent to Guest: PK={}".format(public_key))
L184:
LOGGER.debug("Host computed: Z_H={}, type={}".format(host_ids_process_pair, type(host_ids_process_pair)))
L191:
LOGGER.debug("Host sent to Guest: Z_H={}, type={}".format(host_ids_process, type(host_ids_process)))
L196:
LOGGER.debug("Host received from Guest: Y_G={}, type={}".format(guest_ids, type(guest_ids)))
L201:
LOGGER.debug("Host computed: Z_G={}, type={}".format(guest_ids_process, type(guest_ids_process)))
L205:
LOGGER.debug("Host sent to Guest: Z_G={}, type={}".format(guest_ids_process, type(guest_ids_process)))
L212:
LOGGER.debug("Host received from Guest: S={}, type={}".format(encrypt_intersect_ids, type(encrypt_intersect_ids)))
L215:
LOGGER.debug("Host computed: I_id={}, type={}".format(intersect_ids, type(intersect_ids)))
Host side: source code after the additions
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
from fate_arch.session import computing_session as session
from federatedml.secureprotol import gmpy_math
from federatedml.secureprotol.encrypt import RsaEncrypt
#from federatedml.statistic.intersect.rsa_cache import cache_utils
from federatedml.statistic.intersect import RawIntersect
from federatedml.statistic.intersect import RsaIntersect
from federatedml.util import consts
from federatedml.util import LOGGER
from federatedml.transfer_variable.transfer_class.rsa_intersect_transfer_variable import RsaIntersectTransferVariable
class RsaIntersectionHost(RsaIntersect):
def __init__(self, intersect_params):
super().__init__(intersect_params)
self.transfer_variable = RsaIntersectTransferVariable()
self.e = None
self.d = None
self.n = None
# parameter for intersection cache
self.is_version_match = False
self.has_cache_version = True
def cal_host_ids_process_pair(self, data_instances):
return data_instances.map(
lambda k, v: (
RsaIntersectionHost.hash(gmpy_math.powmod(int(RsaIntersectionHost.hash(k), 16), self.d, self.n)), k)
)
def generate_rsa_key(self, rsa_bit=1024):
encrypt_operator = RsaEncrypt()
encrypt_operator.generate_key(rsa_bit)
return encrypt_operator.get_key_pair()
def get_rsa_key(self):
if self.intersect_cache_param.use_cache:
LOGGER.info("Using intersection cache scheme, start to getting rsa key from cache.")
rsa_key = cache_utils.get_rsa_of_current_version(host_party_id=self.host_party_id,
id_type=self.intersect_cache_param.id_type,
encrypt_type=self.intersect_cache_param.encrypt_type,
tag='Za')
if rsa_key is not None:
e = int(rsa_key.get('rsa_e'))
d = int(rsa_key.get('rsa_d'))
n = int(rsa_key.get('rsa_n'))
else:
self.has_cache_version = False
LOGGER.info("Use cache but can not find any version in cache, set has_cache_version to false")
LOGGER.info("Start to generate rsa key")
e, d, n = self.generate_rsa_key()
else:
LOGGER.info("Not use cache, generate rsa keys.")
e, d, n = self.generate_rsa_key()
return e, d, n
def store_cache(self, host_id, rsa_key: dict, assign_version=None, assign_namespace=None):
store_cache_ret = cache_utils.store_cache(dtable=host_id,
guest_party_id=None,
host_party_id=self.host_party_id,
version=assign_version,
id_type=self.intersect_cache_param.id_type,
encrypt_type=self.intersect_cache_param.encrypt_type,
tag='Za',
namespace=assign_namespace)
LOGGER.info("Finish store host_ids_process to cache")
version = store_cache_ret.get('table_name')
namespace = store_cache_ret.get('namespace')
cache_utils.store_rsa(host_party_id=self.host_party_id,
id_type=self.intersect_cache_param.id_type,
encrypt_type=self.intersect_cache_param.encrypt_type,
tag='Za',
namespace=namespace,
version=version,
rsa=rsa_key
)
LOGGER.info("Finish store rsa key to cache")
return version, namespace
def host_ids_process(self, data_instances):
# (host_id_process, 1)
if self.intersect_cache_param.use_cache:
LOGGER.info("Use intersect cache.")
if self.has_cache_version:
current_version = cache_utils.host_get_current_verison(host_party_id=self.host_party_id,
id_type=self.intersect_cache_param.id_type,
encrypt_type=self.intersect_cache_param.encrypt_type,
tag='Za')
version = current_version.get('table_name')
namespace = current_version.get('namespace')
guest_current_version = self.transfer_variable.cache_version_info.get(0)
LOGGER.info("current_version:{}".format(current_version))
LOGGER.info("guest_current_version:{}".format(guest_current_version))
if guest_current_version.get('table_name') == version \
and guest_current_version.get('namespace') == namespace and \
current_version is not None:
self.is_version_match = True
else:
self.is_version_match = False
version_match_info = {'version_match': self.is_version_match,
'version': version,
'namespace': namespace}
self.transfer_variable.cache_version_match_info.remote(version_match_info,
role=consts.GUEST,
idx=0)
host_ids_process_pair = None
if not self.is_version_match or self.sync_intersect_ids:
# if self.sync_intersect_ids is true, host will get the encrypted intersect id from guest,
# which need the Za to decrypt them
LOGGER.info("read Za from cache")
host_ids_process_pair = session.table(name=version,
namespace=namespace,
create_if_missing=True,
error_if_exist=False)
if host_ids_process_pair.count() == 0:
host_ids_process_pair = self.cal_host_ids_process_pair(data_instances)
rsa_key = {'rsa_e': self.e, 'rsa_d': self.d, 'rsa_n': self.n}
self.store_cache(host_ids_process_pair, rsa_key=rsa_key)
else:
self.is_version_match = False
LOGGER.info("is version_match:{}".format(self.is_version_match))
namespace = cache_utils.gen_cache_namespace(id_type=self.intersect_cache_param.id_type,
encrypt_type=self.intersect_cache_param.encrypt_type,
tag='Za',
host_party_id=self.host_party_id)
version = cache_utils.gen_cache_version(namespace=namespace,
create=True)
version_match_info = {'version_match': self.is_version_match,
'version': version,
'namespace': namespace}
self.transfer_variable.cache_version_match_info.remote(version_match_info,
role=consts.GUEST,
idx=0)
host_ids_process_pair = self.cal_host_ids_process_pair(data_instances)
rsa_key = {'rsa_e': self.e, 'rsa_d': self.d, 'rsa_n': self.n}
self.store_cache(host_ids_process_pair, rsa_key=rsa_key, assign_version=version, assign_namespace=namespace)
LOGGER.info("remote version match info to guest")
else:
LOGGER.info("Not using cache, calculate Za using raw id")
host_ids_process_pair = self.cal_host_ids_process_pair(data_instances)
return host_ids_process_pair
def run(self, data_instances):
LOGGER.info("Start rsa intersection")
self.e, self.d, self.n = self.get_rsa_key()
LOGGER.info("Get rsa key!")
public_key = {"e": self.e, "n": self.n}
LOGGER.debug("Host generated RSA key pair: PK_e={}, PK_n={}, SK_d={}".format(self.e, self.n, self.d))
self.transfer_variable.rsa_pubkey.remote(public_key,
role=consts.GUEST,
idx=0)
LOGGER.info("Remote public key to Guest.")
host_ids_process_pair = self.host_ids_process(data_instances)
LOGGER.debug("Host sent to Guest: PK={}".format(public_key))
if self.intersect_cache_param.use_cache and not self.is_version_match or not self.intersect_cache_param.use_cache:
host_ids_process = host_ids_process_pair.mapValues(lambda v: 1)
LOGGER.debug("Host computed: Z_H={}, type={}".format(host_ids_process_pair, type(host_ids_process_pair)))
self.transfer_variable.intersect_host_ids_process.remote(host_ids_process,
role=consts.GUEST,
idx=0)
LOGGER.info("Remote host_ids_process to Guest.")
# Recv guest ids
LOGGER.debug("Host sent to Guest: Z_H={}, type={}".format(host_ids_process, type(host_ids_process)))
guest_ids = self.transfer_variable.intersect_guest_ids.get(idx=0)
LOGGER.info("Get guest_ids from guest")
# Process guest ids and return to guest
LOGGER.debug("Host received from Guest: Y_G={}, type={}".format(guest_ids, type(guest_ids)))
guest_ids_process = guest_ids.map(lambda k, v: (k, gmpy_math.powmod(int(k), self.d, self.n)))
self.transfer_variable.intersect_guest_ids_process.remote(guest_ids_process,
role=consts.GUEST,
idx=0)
LOGGER.debug("Host computed: Z_G={}, type={}".format(guest_ids_process, type(guest_ids_process)))
LOGGER.info("Remote guest_ids_process to Guest.")
# recv intersect ids
LOGGER.debug("Host sent to Guest: Z_G={}, type={}".format(guest_ids_process, type(guest_ids_process)))
intersect_ids = None
if self.sync_intersect_ids:
encrypt_intersect_ids = self.transfer_variable.intersect_ids.get(idx=0)
intersect_ids_pair = encrypt_intersect_ids.join(host_ids_process_pair, lambda e, h: h)
intersect_ids = intersect_ids_pair.map(lambda k, v: (v, "id"))
LOGGER.info("Get intersect ids from Guest")
LOGGER.debug("Host received from Guest: S={}, type={}".format(encrypt_intersect_ids, type(encrypt_intersect_ids)))
if not self.only_output_key:
intersect_ids = self._get_value_from_data(intersect_ids, data_instances)
LOGGER.debug("Host computed: I_id={}, type={}".format(intersect_ids, type(intersect_ids)))
return intersect_ids
class RawIntersectionHost(RawIntersect):
def __init__(self, intersect_params):
super().__init__(intersect_params)
self.join_role = intersect_params.join_role
self.role = consts.HOST
def run(self, data_instances):
LOGGER.info("Start raw intersection")
if self.join_role == consts.GUEST:
intersect_ids = self.intersect_send_id(data_instances)
elif self.join_role == consts.HOST:
intersect_ids = self.intersect_join_id(data_instances)
else:
raise ValueError("Unknown intersect join role, please check the configure of host")
return intersect_ids
Guest side: log statements added to the source
RsaIntersectionGuest->run()
L142:
LOGGER.debug("Guest received from Host: RSA public key PK_e={}, PK_n={}".format(self.e, self.n))
L149:
LOGGER.debug("Guest computed: Y_G={}, type={}".format(guest_id_process_list, type(guest_id_process_list)))
L157:
LOGGER.debug("Guest sent to Host: Y_G={}, type={}".format(mask_guest_id, type(mask_guest_id)))
L161:
LOGGER.debug("Guest received from Host: Z_H={}, type={}".format(host_ids_process_list, type(host_ids_process_list)))
L167:
LOGGER.debug("Guest received from Host: Z_G={}, type={}".format(recv_guest_ids_process, type(recv_guest_ids_process)))
L173:
LOGGER.debug("Guest computed: D_G={}, type={}".format(guest_ids_process_final, type(guest_ids_process_final)))
L183:
LOGGER.debug("Guest computed: S={}, type={}".format(encrypt_intersect_ids, type(encrypt_intersect_ids)))
L206:
LOGGER.debug("Guest sent to Host: S={}, type={}".format(remote_intersect_id, type(remote_intersect_id)))
L209:
LOGGER.debug("Guest computed: I_id={}, type={}".format(intersect_ids, type(intersect_ids)))
Guest side: source code after the additions
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Iterable
import gmpy2
import random
from fate_arch.session import computing_session as session
from federatedml.secureprotol import gmpy_math
from federatedml.statistic.intersect import RawIntersect
from federatedml.statistic.intersect import RsaIntersect
#from federatedml.statistic.intersect.rsa_cache import cache_utils
from federatedml.util import consts
from federatedml.util import LOGGER
from federatedml.transfer_variable.transfer_class.rsa_intersect_transfer_variable import RsaIntersectTransferVariable
class RsaIntersectionGuest(RsaIntersect):
def __init__(self, intersect_params):
super().__init__(intersect_params)
self.random_bit = intersect_params.random_bit
self.e = None
self.n = None
self.transfer_variable = RsaIntersectTransferVariable()
# parameter for intersection cache
self.intersect_cache_param = intersect_params.intersect_cache_param
def map_raw_id_to_encrypt_id(self, raw_id_data, encrypt_id_data):
encrypt_id_data_exchange_kv = encrypt_id_data.map(lambda k, v: (v, k))
encrypt_raw_id = raw_id_data.join(encrypt_id_data_exchange_kv, lambda r, e: e)
encrypt_common_id = encrypt_raw_id.map(lambda k, v: (v, "id"))
return encrypt_common_id
def get_cache_version_match_info(self):
if self.intersect_cache_param.use_cache:
LOGGER.info("Use cache is true")
# check local cache version for each host
for i, host_party_id in enumerate(self.host_party_id_list):
current_version = cache_utils.guest_get_current_version(host_party_id=host_party_id,
guest_party_id=None,
id_type=self.intersect_cache_param.id_type,
encrypt_type=self.intersect_cache_param.encrypt_type,
tag='Za'
)
LOGGER.info("host_id:{}, current_version:{}".format(host_party_id, current_version))
if current_version is None:
current_version = {"table_name": None, "namespace": None}
self.transfer_variable.cache_version_info.remote(current_version,
role=consts.HOST,
idx=i)
LOGGER.info("Remote current version to host:{}".format(host_party_id))
cache_version_match_info = self.transfer_variable.cache_version_match_info.get(idx=-1)
LOGGER.info("Get cache version match info:{}".format(cache_version_match_info))
else:
cache_version_match_info = None
LOGGER.info("Not using cache, cache_version_match_info is None")
return cache_version_match_info
def get_host_id_process(self, cache_version_match_info):
if self.intersect_cache_param.use_cache:
host_ids_process_list = [None for _ in self.host_party_id_list]
if isinstance(cache_version_match_info, Iterable):
for i, version_info in enumerate(cache_version_match_info):
if version_info.get('version_match'):
host_ids_process_list[i] = session.table(name=version_info.get('version'),
namespace=version_info.get('namespace'),
create_if_missing=True,
error_if_exist=False)
LOGGER.info("Read host {} 's host_ids_process from cache".format(self.host_party_id_list[i]))
else:
LOGGER.info("cache_version_match_info is not iterable, not use cache_version_match_info")
# check which host_id_process is not receive yet
host_ids_process_rev_idx_list = [ i for i, e in enumerate(host_ids_process_list) if e is None ]
if len(host_ids_process_rev_idx_list) > 0:
# Recv host_ids_process
# table(host_id_process, 1)
host_ids_process = []
for rev_idx in host_ids_process_rev_idx_list:
host_ids_process.append(self.transfer_variable.intersect_host_ids_process.get(idx=rev_idx))
LOGGER.info("Get host_ids_process from host {}".format(self.host_party_id_list[rev_idx]))
for i, host_idx in enumerate(host_ids_process_rev_idx_list):
host_ids_process_list[host_idx] = host_ids_process[i]
version = cache_version_match_info[host_idx].get('version')
namespace = cache_version_match_info[host_idx].get('namespace')
cache_utils.store_cache(dtable=host_ids_process[i],
guest_party_id=self.guest_party_id,
host_party_id=self.host_party_id_list[i],
version=version,
id_type=self.intersect_cache_param.id_type,
encrypt_type=self.intersect_cache_param.encrypt_type,
tag=consts.INTERSECT_CACHE_TAG,
namespace=namespace
)
LOGGER.info("Store host {}'s host_ids_process to cache.".format(self.host_party_id_list[host_idx]))
else:
host_ids_process_list = self.transfer_variable.intersect_host_ids_process.get(idx=-1)
LOGGER.info("Not using cache, get host_ids_process from all host")
return host_ids_process_list
@staticmethod
def guest_id_process(sid, random_bit, rsa_e, rsa_n):
r = random.SystemRandom().getrandbits(random_bit)
re_hash = gmpy_math.powmod(r, rsa_e, rsa_n) * int(RsaIntersectionGuest.hash(sid), 16) % rsa_n
return re_hash, (sid, r)
def run(self, data_instances):
LOGGER.info("Start rsa intersection")
public_keys = self.transfer_variable.rsa_pubkey.get(-1)
LOGGER.info("Get RSA public_key:{} from Host".format(public_keys))
self.e = [int(public_key["e"]) for public_key in public_keys]
self.n = [int(public_key["n"]) for public_key in public_keys]
cache_version_match_info = self.get_cache_version_match_info()
LOGGER.debug("Guest received from Host: RSA public key PK_e={}, PK_n={}".format(self.e, self.n))
# table (r^e % n * hash(sid), sid, r)
guest_id_process_list = [ data_instances.map(
lambda k, v: self.guest_id_process(k, random_bit=self.random_bit, rsa_e=self.e[i], rsa_n=self.n[i])) for i in range(len(self.e)) ]
# table(r^e % n *hash(sid), 1)
for i, guest_id in enumerate(guest_id_process_list):
LOGGER.debug("Guest computed: Y_G={}, type={}".format(guest_id_process_list, type(guest_id_process_list)))
mask_guest_id = guest_id.mapValues(lambda v: 1)
self.transfer_variable.intersect_guest_ids.remote(mask_guest_id,
role=consts.HOST,
idx=i)
LOGGER.info("Remote guest_id to Host {}".format(i))
host_ids_process_list = self.get_host_id_process(cache_version_match_info)
LOGGER.debug("Guest sent to Host: Y_G={}, type={}".format(mask_guest_id, type(mask_guest_id)))
LOGGER.info("Get host_ids_process")
# Recv process guest ids
LOGGER.debug("Guest received from Host: Z_H={}, type={}".format(host_ids_process_list, type(host_ids_process_list)))
# table(r^e % n *hash(sid), guest_id_process)
recv_guest_ids_process = self.transfer_variable.intersect_guest_ids_process.get(idx=-1)
LOGGER.info("Get guest_ids_process from Host")
LOGGER.debug("Guest received from Host: Z_G={}, type={}".format(recv_guest_ids_process, type(recv_guest_ids_process)))
# table(r^e % n *hash(sid), sid, hash(guest_ids_process/r))
guest_ids_process_final = [v.join(recv_guest_ids_process[i], lambda g, r: (g[0], RsaIntersectionGuest.hash(gmpy2.divm(int(r), int(g[1]), self.n[i]))))
for i, v in enumerate(guest_id_process_list)]
# table(hash(guest_ids_process/r), sid))
LOGGER.debug("Guest computed: D_G={}, type={}".format(guest_ids_process_final, type(guest_ids_process_final)))
sid_guest_ids_process_final = [
g.map(lambda k, v: (v[1], v[0]))
for i, g in enumerate(guest_ids_process_final)]
# intersect table(hash(guest_ids_process/r), sid)
encrypt_intersect_ids = [v.join(host_ids_process_list[i], lambda sid, h: sid) for i, v in
enumerate(sid_guest_ids_process_final)]
if len(self.host_party_id_list) > 1:
LOGGER.debug("Guest computed: S={}, type={}".format(encrypt_intersect_ids, type(encrypt_intersect_ids)))
raw_intersect_ids = [e.map(lambda k, v: (v, 1)) for e in encrypt_intersect_ids]
intersect_ids = self.get_common_intersection(raw_intersect_ids)
# send intersect id
if self.sync_intersect_ids:
for i, host_party_id in enumerate(self.host_party_id_list):
remote_intersect_id = self.map_raw_id_to_encrypt_id(intersect_ids, encrypt_intersect_ids[i])
self.transfer_variable.intersect_ids.remote(remote_intersect_id,
role=consts.HOST,
idx=i)
LOGGER.info("Remote intersect ids to Host {}!".format(host_party_id))
else:
LOGGER.info("Not send intersect ids to Host!")
else:
intersect_ids = encrypt_intersect_ids[0]
if self.sync_intersect_ids:
remote_intersect_id = intersect_ids.mapValues(lambda v: 1)
self.transfer_variable.intersect_ids.remote(remote_intersect_id,
role=consts.HOST,
idx=0)
intersect_ids = intersect_ids.map(lambda k, v: (v, 1))
LOGGER.debug("Guest sent to Host: S={}, type={}".format(remote_intersect_id, type(remote_intersect_id)))
LOGGER.info("Finish intersect_ids computing")
LOGGER.debug("Guest computed: I_id={}, type={}".format(intersect_ids, type(intersect_ids)))
if not self.only_output_key:
intersect_ids = self._get_value_from_data(intersect_ids, data_instances)
return intersect_ids
class RawIntersectionGuest(RawIntersect):
def __init__(self, intersect_params):
super().__init__(intersect_params)
self.role = consts.GUEST
self.join_role = intersect_params.join_role
def run(self, data_instances):
LOGGER.info("Start raw intersection")
if self.join_role == consts.HOST:
intersect_ids = self.intersect_send_id(data_instances)
elif self.join_role == consts.GUEST:
intersect_ids = self.intersect_join_id(data_instances)
else:
raise ValueError("Unknown intersect join role, please check the configure of guest")
return intersect_ids
2. Model training: in vertical linear regression training, the intermediate values exchanged between the parties are protected with Paillier homomorphic encryption, and the third party (Arbiter) performs the convergence check; the flow and the messages exchanged are shown in the figure below.
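As background for the Paillier-related files below, this minimal sketch (not FATE code; it uses g = n + 1 as fate_paillier.py does, but with toy primes and no fixed-point encoding; Python 3.8+ assumed for pow(x, -1, n)) shows the two homomorphic properties the training loop relies on: ciphertexts can be added, and a ciphertext can be multiplied by a plaintext scalar, which is what lets Guest and Host aggregate encrypted forwards and gradients while only the Arbiter holds the private key.
import math
import random

# toy key; real runs use key_length >= 1024 from PaillierKeypair.generate_keypair
p, q = 61, 53
n = p * q
nsq = n * n
lam = (p - 1) * (q - 1) // math.gcd(p - 1, q - 1)   # lcm(p-1, q-1)
mu = pow(lam, -1, n)                                # because g = n + 1, L(g^lam mod n^2) = lam

def encrypt(m):
    # pick r coprime to n, then c = (1 + m*n) * r^n mod n^2, since (n+1)^m = 1 + m*n (mod n^2)
    while True:
        r = random.SystemRandom().randrange(1, n)
        if math.gcd(r, n) == 1:
            break
    return ((1 + m * n) * pow(r, n, nsq)) % nsq

def decrypt(c):
    # m = L(c^lam mod n^2) * mu mod n, with L(x) = (x - 1) // n
    return ((pow(c, lam, nsq) - 1) // n) * mu % n

c1, c2 = encrypt(123), encrypt(456)
assert decrypt((c1 * c2) % nsq) == 123 + 456   # [[a]] * [[b]] decrypts to a + b
assert decrypt(pow(c1, 5, nsq)) == 5 * 123     # [[a]]^k decrypts to k * a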
Source code:
./federatedml/linear_model/base_linear_model_arbiter.py
HeteroBaseArbiter->fit()
L79:
LOGGER.info("iter:" + str(self.n_iter_))
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from federatedml.framework.hetero.procedure import convergence
from federatedml.framework.hetero.procedure import paillier_cipher, batch_generator
from federatedml.linear_model.linear_model_base import BaseLinearModel
from federatedml.util import LOGGER
from federatedml.util import consts
from federatedml.util import fate_operator
from federatedml.util.validation_strategy import ValidationStrategy
class HeteroBaseArbiter(BaseLinearModel):
def __init__(self):
super(BaseLinearModel, self).__init__()
self.role = consts.ARBITER
# attribute
self.pre_loss = None
self.loss_history = []
self.cipher = paillier_cipher.Arbiter()
self.batch_generator = batch_generator.Arbiter()
self.gradient_loss_operator = None
self.converge_procedure = convergence.Arbiter()
self.best_iteration = -1
def perform_subtasks(self, **training_info):
"""
performs any tasks that the arbiter is responsible for.
This 'perform_subtasks' function serves as a handler on conducting any task that the arbiter is responsible
for. For example, for the 'perform_subtasks' function of 'HeteroDNNLRArbiter' class located in
'hetero_dnn_lr_arbiter.py', it performs some works related to updating/training local neural networks of guest
or host.
For this particular class, the 'perform_subtasks' function will do nothing. In other words, no subtask is
performed by this arbiter.
:param training_info: a dictionary holding training information
"""
pass
def init_validation_strategy(self, train_data=None, validate_data=None):
validation_strategy = ValidationStrategy(self.role, self.mode, self.validation_freqs,
self.early_stopping_rounds,
self.use_first_metric_only)
return validation_strategy
def fit(self, data_instances=None, validate_data=None):
"""
Train linear model of role arbiter
Parameters
----------
data_instances: DTable of Instance, input data
"""
LOGGER.info("Enter hetero linear model arbiter fit")
self.cipher_operator = self.cipher.paillier_keygen(self.model_param.encrypt_param.key_length)
self.batch_generator.initialize_batch_generator()
self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_num)
self.validation_strategy = self.init_validation_strategy(data_instances, validate_data)
while self.n_iter_ < self.max_iter:
LOGGER.info("iter:" + str(self.n_iter_))
iter_loss = None
batch_data_generator = self.batch_generator.generate_batch_data()
total_gradient = None
self.optimizer.set_iters(self.n_iter_)
for batch_index in batch_data_generator:
# Compute and Transfer gradient info
gradient = self.gradient_loss_operator.compute_gradient_procedure(self.cipher_operator,
self.optimizer,
self.n_iter_,
batch_index)
if total_gradient is None:
total_gradient = gradient
else:
total_gradient = total_gradient + gradient
training_info = {"iteration": self.n_iter_, "batch_index": batch_index}
self.perform_subtasks(**training_info)
loss_list = self.gradient_loss_operator.compute_loss(self.cipher_operator, self.n_iter_, batch_index)
if len(loss_list) == 1:
if iter_loss is None:
iter_loss = loss_list[0]
else:
iter_loss += loss_list[0]
# LOGGER.info("Get loss from guest:{}".format(de_loss))
# if converge
if iter_loss is not None:
iter_loss /= self.batch_generator.batch_num
if self.need_call_back_loss:
self.callback_loss(self.n_iter_, iter_loss)
self.loss_history.append(iter_loss)
if self.model_param.early_stop == 'weight_diff':
# LOGGER.debug("total_gradient: {}".format(total_gradient))
weight_diff = fate_operator.norm(total_gradient)
# LOGGER.info("iter: {}, weight_diff:{}, is_converged: {}".format(self.n_iter_,
# weight_diff, self.is_converged))
if weight_diff < self.model_param.tol:
self.is_converged = True
else:
if iter_loss is None:
raise ValueError("Multiple host situation, loss early stop function is not available."
"You should use 'weight_diff' instead")
self.is_converged = self.converge_func.is_converge(iter_loss)
LOGGER.info("iter: {}, loss:{}, is_converged: {}".format(self.n_iter_, iter_loss, self.is_converged))
self.converge_procedure.sync_converge_info(self.is_converged, suffix=(self.n_iter_,))
if self.validation_strategy:
LOGGER.debug('Linear Arbiter running validation')
self.validation_strategy.validate(self, self.n_iter_)
if self.validation_strategy.need_stop():
LOGGER.debug('early stopping triggered')
self.best_iteration = self.n_iter_
break
self.n_iter_ += 1
if self.is_converged:
break
summary = {"loss_history": self.loss_history,
"is_converged": self.is_converged,
"best_iteration": self.best_iteration}
if self.validation_strategy and self.validation_strategy.has_saved_best_model():
self.load_model(self.validation_strategy.cur_best_model)
if self.loss_history is not None and len(self.loss_history) > 0:
summary["best_iter_loss"] = self.loss_history[self.best_iteration]
self.set_summary(summary)
LOGGER.debug("finish running linear model arbiter")
./federatedml/framework/hetero/sync/converge_sync.py
L19~20:
from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()
Arbiter->sync_converge_info()
L29:
LOGGER.debug("Arbiter sent to Host: is_converged={}, type={}".format(is_converged, type(is_converged)))
L31:
LOGGER.debug("Arbiter sent to Guest: is_converged={}, type={}".format(is_converged, type(is_converged)))
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from federatedml.util import consts
from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()
class Arbiter(object):
# noinspection PyAttributeOutsideInit
def _register_convergence(self, is_stopped_transfer):
self._is_stopped_transfer = is_stopped_transfer
def sync_converge_info(self, is_converged, suffix=tuple()):
self._is_stopped_transfer.remote(obj=is_converged, role=consts.HOST, idx=-1, suffix=suffix)
self._is_stopped_transfer.remote(obj=is_converged, role=consts.GUEST, idx=-1, suffix=suffix)
LOGGER.debug("Arbiter sent to Host: is_converged={}, type={}".format(is_converged, type(is_converged)))
LOGGER.debug("Arbiter sent to Guest: is_converged={}, type={}".format(is_converged, type(is_converged)))
class _Client(object):
# noinspection PyAttributeOutsideInit
def _register_convergence(self, is_stopped_transfer):
self._is_stopped_transfer = is_stopped_transfer
def sync_converge_info(self, suffix=tuple()):
is_converged = self._is_stopped_transfer.get(idx=0, suffix=suffix)
return is_converged
Host = _Client
Guest = _Client
./federatedml/framework/hetero/sync/loss_sync.py
L21~22:
from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()
Guest->sync_loss_info()
L41:LOGGER.debug("Guest sent to Arbiter: [[L]]={}, type={}".format(loss, type(loss)))
Guest->get_host_loss_intermediate()
L45:LOGGER.debug("Guest received from Host: [[(u^H)^2]]={}, type={}".format(loss_intermediate, type(loss_intermediate)))
Guest->get_host_loss_regular()
L50:
LOGGER.debug("Guest received from Host: [[(λ/2)(Θ^H)^2]]={}, type={}".format(losses, type(losses)))
Host->remote_loss_intermediate()
L62:
LOGGER.debug("Host sent to Guest: [[(u^H)^2]]={}, type={}".format(loss_intermediate, type(loss_intermediate)))
Host->remote_loss_regular()
L66:
LOGGER.debug("Host sent to Guest: [[(λ/2)(Θ^H)^2]]={}, type={}".format(loss_regular, type(loss_regular)))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from federatedml.util import consts
from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()
class Arbiter(object):
def _register_loss_sync(self, loss_transfer):
self.loss_transfer = loss_transfer
def sync_loss_info(self, suffix=tuple()):
loss = self.loss_transfer.get(idx=0, suffix=suffix)
return loss
class Guest(object):
def _register_loss_sync(self, host_loss_regular_transfer, loss_transfer, loss_intermediate_transfer):
self.host_loss_regular_transfer = host_loss_regular_transfer
self.loss_transfer = loss_transfer
self.loss_intermediate_transfer = loss_intermediate_transfer
def sync_loss_info(self, loss, suffix=tuple()):
self.loss_transfer.remote(loss, role=consts.ARBITER, idx=0, suffix=suffix)
LOGGER.debug("Guest sent to Arbiter: [[L]]={}, type={}".format(loss, type(loss)))
def get_host_loss_intermediate(self, suffix=tuple()):
loss_intermediate = self.loss_intermediate_transfer.get(idx=-1, suffix=suffix)
LOGGER.debug("Guest received from Host: [[(u^H)^2]]={}, type={}".format(loss_intermediate, type(loss_intermediate)))
return loss_intermediate
def get_host_loss_regular(self, suffix=tuple()):
losses = self.host_loss_regular_transfer.get(idx=-1, suffix=suffix)
LOGGER.debug("Guest received from Host: [[(λ/2)(Θ^H)^2]]={}, type={}".format(losses, type(losses)))
return losses
class Host(object):
def _register_loss_sync(self, host_loss_regular_transfer, loss_transfer, loss_intermediate_transfer):
self.host_loss_regular_transfer = host_loss_regular_transfer
self.loss_transfer = loss_transfer
self.loss_intermediate_transfer = loss_intermediate_transfer
def remote_loss_intermediate(self, loss_intermediate, suffix=tuple()):
self.loss_intermediate_transfer.remote(obj=loss_intermediate, role=consts.GUEST, idx=0, suffix=suffix)
LOGGER.debug("Host sent to Guest: [[(u^H)^2]]={}, type={}".format(loss_intermediate, type(loss_intermediate)))
def remote_loss_regular(self, loss_regular, suffix=tuple()):
self.host_loss_regular_transfer.remote(obj=loss_regular, role=consts.GUEST, idx=0, suffix=suffix)
LOGGER.debug("Host sent to Guest: [[(λ/2)(Θ^H)^2]]={}, type={}".format(loss_regular, type(loss_regular)))
./federatedml/framework/hetero/sync/paillier_keygen_sync.py
L20~21:
from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()
Arbiter->paillier_keygen()
L32~33:
priv_key = cipher.get_privacy_key()
LOGGER.debug("Arbiter generated a Paillier key pair: PK={}, SK={}".format(pub_key, priv_key))
L35:
LOGGER.debug("Arbiter sent to Host: PK={}".format(pub_key))
L37:
LOGGER.debug("Arbiter sent to Guest: PK={}".format(pub_key))
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from federatedml.secureprotol.encrypt import PaillierEncrypt
from federatedml.util import consts
from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()
class Arbiter(object):
# noinspection PyAttributeOutsideInit
def _register_paillier_keygen(self, pubkey_transfer):
self._pubkey_transfer = pubkey_transfer
def paillier_keygen(self, key_length, suffix=tuple()):
cipher = PaillierEncrypt()
cipher.generate_key(key_length)
pub_key = cipher.get_public_key()
priv_key = cipher.get_privacy_key()
LOGGER.debug("Arbiter generated a Paillier key pair: PK={}, SK={}".format(pub_key, priv_key))
self._pubkey_transfer.remote(obj=pub_key, role=consts.HOST, idx=-1, suffix=suffix)
LOGGER.debug("Arbiter sent to Host: PK={}".format(pub_key))
self._pubkey_transfer.remote(obj=pub_key, role=consts.GUEST, idx=-1, suffix=suffix)
LOGGER.debug("Arbiter sent to Guest: PK={}".format(pub_key))
return cipher
class _Client(object):
# noinspection PyAttributeOutsideInit
def _register_paillier_keygen(self, pubkey_transfer):
self._pubkey_transfer = pubkey_transfer
def gen_paillier_cipher_operator(self, suffix=tuple()):
pubkey = self._pubkey_transfer.get(idx=0, suffix=suffix)
cipher = PaillierEncrypt()
cipher.set_public_key(pubkey)
return cipher
Host = _Client
Guest = _Client
./federatedml/secureprotol/fate_paillier.py
L24~25:
from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()
PaillierPublicKey
L71~72:
def __str__(self):
return "PaillierPublicKey: g={}, n={}, max_int={}".format(hex(self.g), hex(self.n), hex(self.max_int))
PaillierPublicKey->apply_obfuscator()
L80~83:
new_ciphertext = (ciphertext * obfuscator) % self.nsquare
LOGGER.debug("PaillierPublicKey->apply_obfuscator(): old_cipher={}, new_cipher={}, random={}".format(
hex(ciphertext), hex(new_ciphertext), hex(r)))
return new_ciphertext
PaillierPublicKey->raw_encrypt()
L100~101:
LOGGER.debug("PaillierPublicKey->raw_encrypt(): plaintext={}, ciphertext={}, PK={}".format(
hex(plaintext), hex(ciphertext), self))
PaillierPrivateKey
L151~153:
def __str__(self):
return "PaillierPrivateKey: p={}, q={}".format(
hex(self.p), hex(self.q))
PaillierPrivateKey->raw_decrypt()
L191~194:
result = self.crt(mp, mq)
LOGGER.debug("PaillierPrivateKey->raw_decrypt(): ciphertext={}, plaintext={}, SK".format(
hex(ciphertext), hex(result), self))
return result
PaillierEncryptedNumber
L231~233:
def __str__(self):
return "PaillierEncryptedNumber: public_key={}, ciphertext={}, exponent={}".format(
self.public_key, hex(self.__ciphertext), hex(self.exponent))
"""Paillier encryption library for partially homomorphic encryption."""
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Mapping
from federatedml.secureprotol.fixedpoint import FixedPointNumber
from federatedml.secureprotol import gmpy_math
import random
from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()
class PaillierKeypair(object):
def __init__(self):
pass
@staticmethod
def generate_keypair(n_length=1024):
"""return a new :class:`PaillierPublicKey` and :class:`PaillierPrivateKey`.
"""
p = q = n = None
n_len = 0
while n_len != n_length:
p = gmpy_math.getprimeover(n_length // 2)
q = p
while q == p:
q = gmpy_math.getprimeover(n_length // 2)
n = p * q
n_len = n.bit_length()
public_key = PaillierPublicKey(n)
private_key = PaillierPrivateKey(public_key, p, q)
return public_key, private_key
class PaillierPublicKey(object):
"""Contains a public key and associated encryption methods.
"""
def __init__(self, n):
self.g = n + 1
self.n = n
self.nsquare = n * n
self.max_int = n // 3 - 1
def __repr__(self):
hashcode = hex(hash(self))[2:]
return "<PaillierPublicKey {}>".format(hashcode[:10])
def __eq__(self, other):
return self.n == other.n
def __hash__(self):
return hash(self.n)
def __str__(self):
return "PaillierPublicKey: g={}, n={}, max_int={}".format(hex(self.g), hex(self.n), hex(self.max_int))
def apply_obfuscator(self, ciphertext, random_value=None):
"""
"""
r = random_value or random.SystemRandom().randrange(1, self.n)
obfuscator = gmpy_math.powmod(r, self.n, self.nsquare)
new_ciphertext = (ciphertext * obfuscator) % self.nsquare
LOGGER.debug("PaillierPublicKey->apply_obfuscator(): old_cipher={}, new_cipher={}, random={}".format(
hex(ciphertext), hex(new_ciphertext), hex(r)))
return new_ciphertext
def raw_encrypt(self, plaintext, random_value=None):
"""
"""
if not isinstance(plaintext, int):
raise TypeError("plaintext should be int, but got: %s" %
type(plaintext))
if plaintext >= (self.n - self.max_int) and plaintext < self.n:
# Very large plaintext, take a sneaky shortcut using inverses
neg_plaintext = self.n - plaintext # = abs(plaintext - nsquare)
neg_ciphertext = (self.n * neg_plaintext + 1) % self.nsquare
ciphertext = gmpy_math.invert(neg_ciphertext, self.nsquare)
else:
ciphertext = (self.n * plaintext + 1) % self.nsquare
LOGGER.debug("PaillierPublicKey->raw_encrypt(): plaintext={}, ciphertext={}, PK={}".format(
hex(plaintext), hex(ciphertext), self))
ciphertext = self.apply_obfuscator(ciphertext, random_value)
return ciphertext
def encrypt(self, value, precision=None, random_value=None):
"""Encode and Paillier encrypt a real number value.
"""
encoding = FixedPointNumber.encode(value, self.n, self.max_int, precision)
obfuscator = random_value or 1
ciphertext = self.raw_encrypt(encoding.encoding, random_value=obfuscator)
encryptednumber = PaillierEncryptedNumber(self, ciphertext, encoding.exponent)
if random_value is None:
encryptednumber.apply_obfuscator()
return encryptednumber
class PaillierPrivateKey(object):
"""Contains a private key and associated decryption method.
"""
def __init__(self, public_key, p, q):
if not p * q == public_key.n:
raise ValueError("given public key does not match the given p and q")
if p == q:
raise ValueError("p and q have to be different")
self.public_key = public_key
if q < p:
self.p = q
self.q = p
else:
self.p = p
self.q = q
self.psquare = self.p * self.p
self.qsquare = self.q * self.q
self.q_inverse = gmpy_math.invert(self.q, self.p)
self.hp = self.h_func(self.p, self.psquare)
self.hq = self.h_func(self.q, self.qsquare)
def __eq__(self, other):
return self.p == other.p and self.q == other.q
def __hash__(self):
return hash((self.p, self.q))
def __repr__(self):
hashcode = hex(hash(self))[2:]
return "<PaillierPrivateKey {}>".format(hashcode[:10])
def __str__(self):
return "PaillierPrivateKey: p={}, q={}".format(
hex(self.p), hex(self.q))
def h_func(self, x, xsquare):
"""Computes the h-function as defined in Paillier's paper page.
"""
return gmpy_math.invert(self.l_func(gmpy_math.powmod(self.public_key.g,
x - 1, xsquare), x), x)
def l_func(self, x, p):
"""computes the L function as defined in Paillier's paper.
"""
return (x - 1) // p
def crt(self, mp, mq):
"""the Chinese Remainder Theorem as needed for decryption.
return the solution modulo n=pq.
"""
u = (mp - mq) * self.q_inverse % self.p
x = (mq + (u * self.q)) % self.public_key.n
return x
def raw_decrypt(self, ciphertext):
"""return raw plaintext.
"""
if not isinstance(ciphertext, int):
raise TypeError("ciphertext should be an int, not: %s" %
type(ciphertext))
mp = self.l_func(gmpy_math.powmod(ciphertext,
self.p-1, self.psquare),
self.p) * self.hp % self.p
mq = self.l_func(gmpy_math.powmod(ciphertext,
self.q-1, self.qsquare),
self.q) * self.hq % self.q
result = self.crt(mp, mq)
LOGGER.debug("PaillierPrivateKey->raw_decrypt(): ciphertext={}, plaintext={}, SK".format(
hex(ciphertext), hex(result), self))
return result
def decrypt(self, encrypted_number):
"""return the decrypted & decoded plaintext of encrypted_number.
"""
if not isinstance(encrypted_number, PaillierEncryptedNumber):
raise TypeError("encrypted_number should be an PaillierEncryptedNumber, \
not: %s" % type(encrypted_number))
if self.public_key != encrypted_number.public_key:
raise ValueError("encrypted_number was encrypted against a different key!")
encoded = self.raw_decrypt(encrypted_number.ciphertext(be_secure=False))
encoded = FixedPointNumber(encoded,
encrypted_number.exponent,
self.public_key.n,
self.public_key.max_int)
decrypt_value = encoded.decode()
return decrypt_value
class PaillierEncryptedNumber(object):
"""Represents the Paillier encryption of a float or int.
"""
def __init__(self, public_key, ciphertext, exponent=0):
self.public_key = public_key
self.__ciphertext = ciphertext
self.exponent = exponent
self.__is_obfuscator = False
if not isinstance(self.__ciphertext, int):
raise TypeError("ciphertext should be an int, not: %s" % type(self.__ciphertext))
if not isinstance(self.public_key, PaillierPublicKey):
raise TypeError("public_key should be a PaillierPublicKey, not: %s" % type(self.public_key))
def __str__(self):
return "PaillierEncryptedNumber: public_key={}, ciphertext={}, exponent={}".format(
self.public_key, hex(self.__ciphertext), hex(self.exponent))
def ciphertext(self, be_secure=True):
"""return the ciphertext of the PaillierEncryptedNumber.
"""
if be_secure and not self.__is_obfuscator:
self.apply_obfuscator()
return self.__ciphertext
def apply_obfuscator(self):
"""ciphertext by multiplying by r ** n with random r
"""
self.__ciphertext = self.public_key.apply_obfuscator(self.__ciphertext)
self.__is_obfuscator = True
def __add__(self, other):
if isinstance(other, PaillierEncryptedNumber):
return self.__add_encryptednumber(other)
else:
return self.__add_scalar(other)
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
return self + (other * -1)
def __rsub__(self, other):
return other + (self * -1)
def __rmul__(self, scalar):
return self.__mul__(scalar)
def __truediv__(self, scalar):
return self.__mul__(1 / scalar)
def __mul__(self, scalar):
"""return Multiply by an scalar(such as int, float)
"""
encode = FixedPointNumber.encode(scalar, self.public_key.n, self.public_key.max_int)
plaintext = encode.encoding
if plaintext < 0 or plaintext >= self.public_key.n:
raise ValueError("Scalar out of bounds: %i" % plaintext)
if plaintext >= self.public_key.n - self.public_key.max_int:
# Very large plaintext, play a sneaky trick using inverses
neg_c = gmpy_math.invert(self.ciphertext(False), self.public_key.nsquare)
neg_scalar = self.public_key.n - plaintext
ciphertext = gmpy_math.powmod(neg_c, neg_scalar, self.public_key.nsquare)
else:
ciphertext = gmpy_math.powmod(self.ciphertext(False), plaintext, self.public_key.nsquare)
exponent = self.exponent + encode.exponent
return PaillierEncryptedNumber(self.public_key, ciphertext, exponent)
def increase_exponent_to(self, new_exponent):
"""return PaillierEncryptedNumber:
new PaillierEncryptedNumber with same value but having great exponent.
"""
if new_exponent < self.exponent:
raise ValueError("New exponent %i should be great than old exponent %i" % (new_exponent, self.exponent))
factor = pow(FixedPointNumber.BASE, new_exponent - self.exponent)
new_encryptednumber = self.__mul__(factor)
new_encryptednumber.exponent = new_exponent
return new_encryptednumber
def __align_exponent(self, x, y):
"""return x,y with same exponet
"""
if x.exponent < y.exponent:
x = x.increase_exponent_to(y.exponent)
elif x.exponent > y.exponent:
y = y.increase_exponent_to(x.exponent)
return x, y
def __add_scalar(self, scalar):
"""return PaillierEncryptedNumber: z = E(x) + y
"""
encoded = FixedPointNumber.encode(scalar,
self.public_key.n,
self.public_key.max_int,
max_exponent=self.exponent)
return self.__add_fixpointnumber(encoded)
def __add_fixpointnumber(self, encoded):
"""return PaillierEncryptedNumber: z = E(x) + FixedPointNumber(y)
"""
if self.public_key.n != encoded.n:
raise ValueError("Attempted to add numbers encoded against different public keys!")
# their exponents must match, and align.
x, y = self.__align_exponent(self, encoded)
encrypted_scalar = x.public_key.raw_encrypt(y.encoding, 1)
encryptednumber = self.__raw_add(x.ciphertext(False), encrypted_scalar, x.exponent)
return encryptednumber
def __add_encryptednumber(self, other):
"""return PaillierEncryptedNumber: z = E(x) + E(y)
"""
if self.public_key != other.public_key:
raise ValueError("add two numbers have different public key!")
# their exponents must match, and align.
x, y = self.__align_exponent(self, other)
encryptednumber = self.__raw_add(x.ciphertext(False), y.ciphertext(False), x.exponent)
return encryptednumber
def __raw_add(self, e_x, e_y, exponent):
"""return the integer E(x + y) given ints E(x) and E(y).
"""
ciphertext = e_x * e_y % self.public_key.nsquare
return PaillierEncryptedNumber(self.public_key, ciphertext, exponent)
./federatedml/secureprotol/fixedpoint.py
L22~23:
from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()
FixedPointNumber->encode()
L86~88:
result = cls(int_fixpoint % n, exponent, n, max_int)
LOGGER.debug("FixedPointNumber->encode(): value={}, FP={}".format(scalar, result))
return result
FixedPointNumber->decode()
L105~107:
result = mantissa * pow(self.BASE, -self.exponent)
LOGGER.debug("FixedPointNumber->decode(): FP={}, value={}".format(self, result))
return result
FixedPointNumber
L271~273:
def __str__(self):
return "FixedPointNumber: encoding={}, exponent={}".format(
hex(self.encoding), hex(self.exponent))
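To make the encode()/decode() log lines easier to interpret, here is the arithmetic worked through by hand for a single value, using the same BASE = 16 rule as FixedPointNumber (a standalone illustration, not a call into fixedpoint.py):
import math
import sys

BASE = 16
value = 0.5
# same exponent rule as FixedPointNumber.encode(): keep the float's last mantissa bit representable
flt_exponent = math.frexp(value)[1]                                                  # 0 for 0.5
exponent = math.floor((sys.float_info.mant_dig - flt_exponent) / math.log(BASE, 2))  # 13
encoding = int(round(value * BASE ** exponent))                                      # 2251799813685248
decoded = encoding * BASE ** -exponent                                               # back to 0.5
assert decoded == value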
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
import sys
import numpy as np
from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()
class FixedPointNumber(object):
"""Represents a float or int fixedpoit encoding;.
"""
BASE = 16
LOG2_BASE = math.log(BASE, 2)
FLOAT_MANTISSA_BITS = sys.float_info.mant_dig
Q = 293973345475167247070445277780365744413
def __init__(self, encoding, exponent, n=None, max_int=None):
self.n = n
self.max_int = max_int
if self.n is None:
self.n = self.Q
self.max_int = self.Q // 3 - 1
self.encoding = encoding
self.exponent = exponent
@classmethod
def encode(cls, scalar, n=None, max_int=None, precision=None, max_exponent=None):
"""return an encoding of an int or float.
"""
# Calculate the maximum exponent for desired precision
exponent = None
# Too low value preprocess;
# avoid "OverflowError: int too large to convert to float"
if np.abs(scalar) < 1e-200:
scalar = 0
if n is None:
n = cls.Q
max_int = cls.Q // 3 - 1
if precision is None:
if isinstance(scalar, int) or isinstance(scalar, np.int16) or \
isinstance(scalar, np.int32) or isinstance(scalar, np.int64):
exponent = 0
elif isinstance(scalar, float) or isinstance(scalar, np.float16) \
or isinstance(scalar, np.float32) or isinstance(scalar, np.float64):
flt_exponent = math.frexp(scalar)[1]
lsb_exponent = cls.FLOAT_MANTISSA_BITS - flt_exponent
exponent = math.floor(lsb_exponent / cls.LOG2_BASE)
else:
raise TypeError("Don't know the precision of type %s."
% type(scalar))
else:
exponent = math.floor(math.log(precision, cls.BASE))
if max_exponent is not None:
exponent = max(max_exponent, exponent)
int_fixpoint = int(round(scalar * pow(cls.BASE, exponent)))
if abs(int_fixpoint) > max_int:
raise ValueError('Integer needs to be within +/- %d but got %d'
% (max_int, int_fixpoint))
result = cls(int_fixpoint % n, exponent, n, max_int)
LOGGER.debug("FixedPointNumber->encode(): value={}, FP={}".format(scalar, result))
return result
def decode(self):
"""return decode plaintext.
"""
if self.encoding >= self.n:
# Should be mod n
raise ValueError('Attempted to decode corrupted number')
elif self.encoding <= self.max_int:
# Positive
mantissa = self.encoding
elif self.encoding >= self.n - self.max_int:
# Negative
mantissa = self.encoding - self.n
else:
raise OverflowError('Overflow detected in decode number')
result = mantissa * pow(self.BASE, -self.exponent)
LOGGER.debug("FixedPointNumber->decode(): FP={}, value={}".format(self, result))
return result
def increase_exponent_to(self, new_exponent):
"""return FixedPointNumber: new encoding with same value but having great exponent.
"""
if new_exponent < self.exponent:
raise ValueError('New exponent %i should be greater than'
'old exponent %i' % (new_exponent, self.exponent))
factor = pow(self.BASE, new_exponent - self.exponent)
new_encoding = self.encoding * factor % self.n
return FixedPointNumber(new_encoding, new_exponent, self.n, self.max_int)
def __align_exponent(self, x, y):
"""return x,y with same exponet
"""
if x.exponent < y.exponent:
x = x.increase_exponent_to(y.exponent)
elif x.exponent > y.exponent:
y = y.increase_exponent_to(x.exponent)
return x, y
def __truncate(self, a):
scalar = a.decode()
return FixedPointNumber.encode(scalar)
def __add__(self, other):
if isinstance(other, FixedPointNumber):
return self.__add_fixpointnumber(other)
else:
return self.__add_scalar(other)
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
if isinstance(other, FixedPointNumber):
return self.__sub_fixpointnumber(other)
else:
return self.__sub_scalar(other)
def __rsub__(self, other):
x = self.__sub__(other)
x = -1 * x.decode()
return self.encode(x)
def __rmul__(self, other):
return self.__mul__(other)
def __mul__(self, other):
if isinstance(other, FixedPointNumber):
return self.__mul_fixpointnumber(other)
else:
return self.__mul_scalar(other)
def __truediv__(self, other):
if isinstance(other, FixedPointNumber):
scalar = other.decode()
else:
scalar = other
return self.__mul__(1 / scalar)
def __rtruediv__(self, other):
res = 1.0 / self.__truediv__(other).decode()
return FixedPointNumber.encode(res)
def __lt__(self, other):
x = self.decode()
if isinstance(other, FixedPointNumber):
y = other.decode()
else:
y = other
if x < y:
return True
else:
return False
def __gt__(self, other):
x = self.decode()
if isinstance(other, FixedPointNumber):
y = other.decode()
else:
y = other
if x > y:
return True
else:
return False
def __le__(self, other):
x = self.decode()
if isinstance(other, FixedPointNumber):
y = other.decode()
else:
y = other
if x <= y:
return True
else:
return False
def __ge__(self, other):
x = self.decode()
if isinstance(other, FixedPointNumber):
y = other.decode()
else:
y = other
if x >= y:
return True
else:
return False
def __eq__(self, other):
x = self.decode()
if isinstance(other, FixedPointNumber):
y = other.decode()
else:
y = other
if x == y:
return True
else:
return False
def __ne__(self, other):
x = self.decode()
if isinstance(other, FixedPointNumber):
y = other.decode()
else:
y = other
if x != y:
return True
else:
return False
def __add_fixpointnumber(self, other):
x, y = self.__align_exponent(self, other)
encoding = (x.encoding + y.encoding) % self.Q
return FixedPointNumber(encoding, x.exponent)
def __add_scalar(self, scalar):
encoded = self.encode(scalar)
return self.__add_fixpointnumber(encoded)
def __sub_fixpointnumber(self, other):
scalar = -1 * other.decode()
return self.__add_scalar(scalar)
def __sub_scalar(self, scalar):
scalar = -1 * scalar
return self.__add_scalar(scalar)
def __mul_fixpointnumber(self, other):
encoding = (self.encoding * other.encoding) % self.Q
exponet = self.exponent + other.exponent
mul_fixedpoint = FixedPointNumber(encoding, exponet)
truncate_mul_fixedpoint = self.__truncate(mul_fixedpoint)
return truncate_mul_fixedpoint
def __mul_scalar(self, scalar):
encoded = self.encode(scalar)
return self.__mul_fixpointnumber(encoded)
def __str__(self):
return "FixedPointNumber: encoding={}, exponent={}".format(
hex(self.encoding), hex(self.exponent))
./federatedml/optim/gradient/hetero_linear_model_gradient.py
Guest->compute_gradient_procedure()
L178:
LOGGER.debug("Guest computed: [[d]]x^G={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
L181:
LOGGER.debug("Guest computed: [[d]]x^G+[[λΘ^G]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
Guest->get_host_forward()
L188:
LOGGER.debug("Guest received from Host: [[u^H]]={}, type={}".format(host_forward, type(host_forward)))
Guest->remote_fore_gradient()
L193:
LOGGER.debug("Guest sent to Host: [[d]]={}, type={}".format(fore_gradient, type(fore_gradient)))
Guest->update_gradient()
L197:
LOGGER.debug("Guest sent to Arbiter: [[g^G]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
L199:
LOGGER.debug("Guest received from Arbiter: g^G={}, type={}".format(optimized_gradient, type(optimized_gradient)))
Host->compute_gradient_procedure()
L234:
LOGGER.debug("Host computed: u^H={}, type={}".format(list(self.forwards.collect()), type(self.forwards)))
L236:
LOGGER.debug("Host encrypted: [[u^H]]={}, type={}".format(list(encrypted_forward.collect()), type(encrypted_forward)))
L244:
LOGGER.debug("Host computed: [[d]]x^H={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
L247:
LOGGER.debug("Host computed: [[d]]x^H+[[λΘ^H]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
Host->remote_host_forward()
L278:
LOGGER.debug("Host sent to Guest: [[u^H]]={}, type={}".format(host_forward, type(host_forward)))
Host->get_fore_gradient()
L282:
LOGGER.debug("Host received from Guest: fore_gradient={}, type={}".format(host_forward, type(host_forward)))
Host->update_gradient()
L287:
LOGGER.debug("Host sent to Arbiter: [[g^H]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
L289:
LOGGER.debug("Host received from Arbiter: g^H={}, type={}".format(optimized_gradient, type(optimized_gradient)))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import numpy as np
import scipy.sparse as sp
from federatedml.feature.sparse_vector import SparseVector
from federatedml.statistic import data_overview
from federatedml.util import LOGGER
from federatedml.util import consts
from federatedml.util import fate_operator
def __compute_partition_gradient(data, fit_intercept=True, is_sparse=False):
"""
Compute hetero regression gradient for:
gradient = ∑d*x, where d is fore_gradient which differ from different algorithm
Parameters
----------
data: DTable, include fore_gradient and features
fit_intercept: bool, if model has interception or not. Default True
Returns
----------
numpy.ndarray
hetero regression model gradient
"""
feature = []
fore_gradient = []
if is_sparse:
row_indice = []
col_indice = []
data_value = []
row = 0
feature_shape = None
for key, (sparse_features, d) in data:
fore_gradient.append(d)
assert isinstance(sparse_features, SparseVector)
if feature_shape is None:
feature_shape = sparse_features.get_shape()
for idx, v in sparse_features.get_all_data():
col_indice.append(idx)
row_indice.append(row)
data_value.append(v)
row += 1
if feature_shape is None or feature_shape == 0:
return 0
sparse_matrix = sp.csr_matrix((data_value, (row_indice, col_indice)), shape=(row, feature_shape))
fore_gradient = np.array(fore_gradient)
# gradient = sparse_matrix.transpose().dot(fore_gradient).tolist()
gradient = fate_operator.dot(sparse_matrix.transpose(), fore_gradient).tolist()
if fit_intercept:
bias_grad = np.sum(fore_gradient)
gradient.append(bias_grad)
# LOGGER.debug("In first method, gradient: {}, bias_grad: {}".format(gradient, bias_grad))
return np.array(gradient)
else:
for key, value in data:
feature.append(value[0])
fore_gradient.append(value[1])
feature = np.array(feature)
fore_gradient = np.array(fore_gradient)
if feature.shape[0] <= 0:
return 0
gradient = fate_operator.dot(feature.transpose(), fore_gradient)
gradient = gradient.tolist()
if fit_intercept:
bias_grad = np.sum(fore_gradient)
gradient.append(bias_grad)
return np.array(gradient)
def compute_gradient(data_instances, fore_gradient, fit_intercept):
"""
Compute hetero-regression gradient
Parameters
----------
data_instances: DTable, input data
fore_gradient: DTable, fore_gradient
fit_intercept: bool, if model has intercept or not
Returns
----------
DTable
the hetero regression model's gradient
"""
feat_join_grad = data_instances.join(fore_gradient,
lambda d, g: (d.features, g))
is_sparse = data_overview.is_sparse_data(data_instances)
f = functools.partial(__compute_partition_gradient,
fit_intercept=fit_intercept,
is_sparse=is_sparse)
gradient_partition = feat_join_grad.applyPartitions(f)
gradient_partition = gradient_partition.reduce(lambda x, y: x + y)
gradient = gradient_partition / data_instances.count()
return gradient
class HeteroGradientBase(object):
def compute_gradient_procedure(self, *args):
raise NotImplementedError("Should not call here")
def set_total_batch_nums(self, total_batch_nums):
"""
Use for sqn gradient.
"""
pass
class Guest(HeteroGradientBase):
def __init__(self):
self.host_forwards = None
self.forwards = None
self.aggregated_forwards = None
def _register_gradient_sync(self, host_forward_transfer, fore_gradient_transfer,
guest_gradient_transfer, guest_optim_gradient_transfer):
self.host_forward_transfer = host_forward_transfer
self.fore_gradient_transfer = fore_gradient_transfer
self.unilateral_gradient_transfer = guest_gradient_transfer
self.unilateral_optim_gradient_transfer = guest_optim_gradient_transfer
def compute_and_aggregate_forwards(self, data_instances, model_weights,
encrypted_calculator, batch_index, current_suffix, offset=None):
raise NotImplementedError("Function should not be called here")
def compute_gradient_procedure(self, data_instances, encrypted_calculator, model_weights, optimizer,
n_iter_, batch_index, offset=None):
"""
Linear model gradient procedure
Step 1: Get host forwards, which differ between algorithms
For Logistic Regression and Linear Regression: forwards = wx
For Poisson Regression, forwards = exp(wx)
Step 2: Compute the guest's own forwards, aggregate them with the host forwards, and get d = fore_gradient
Step 3: Compute unilateral gradient = ∑d*x,
Step 4: Send the unilateral gradient to the arbiter and receive the optimized and decrypted gradient.
"""
current_suffix = (n_iter_, batch_index)
# self.host_forwards = self.get_host_forward(suffix=current_suffix)
fore_gradient = self.compute_and_aggregate_forwards(data_instances, model_weights, encrypted_calculator,
batch_index, current_suffix, offset)
self.remote_fore_gradient(fore_gradient, suffix=current_suffix)
unilateral_gradient = compute_gradient(data_instances,
fore_gradient,
model_weights.fit_intercept)
LOGGER.debug("Guest computed: [[d]]x^G={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
if optimizer is not None:
unilateral_gradient = optimizer.add_regular_to_grad(unilateral_gradient, model_weights)
LOGGER.debug("Guest computed: [[d]]x^G+[[λΘ^G]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
optimized_gradient = self.update_gradient(unilateral_gradient, suffix=current_suffix)
return optimized_gradient, fore_gradient, self.host_forwards
def get_host_forward(self, suffix=tuple()):
host_forward = self.host_forward_transfer.get(idx=-1, suffix=suffix)
LOGGER.debug("Guest received from Host: [[u^H]]={}, type={}".format(host_forward, type(host_forward)))
return host_forward
def remote_fore_gradient(self, fore_gradient, suffix=tuple()):
self.fore_gradient_transfer.remote(obj=fore_gradient, role=consts.HOST, idx=-1, suffix=suffix)
LOGGER.debug("Guest sent to Host: [[d]]={}, type={}".format(fore_gradient, type(fore_gradient)))
def update_gradient(self, unilateral_gradient, suffix=tuple()):
self.unilateral_gradient_transfer.remote(unilateral_gradient, role=consts.ARBITER, idx=0, suffix=suffix)
LOGGER.debug("Guest sent to Arbiter: [[g^G]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
optimized_gradient = self.unilateral_optim_gradient_transfer.get(idx=0, suffix=suffix)
LOGGER.debug("Guest received from Arbiter: g^G={}, type={}".format(optimized_gradient, type(optimized_gradient)))
return optimized_gradient
class Host(HeteroGradientBase):
def __init__(self):
self.forwards = None
self.fore_gradient = None
def _register_gradient_sync(self, host_forward_transfer, fore_gradient_transfer,
host_gradient_transfer, host_optim_gradient_transfer):
self.host_forward_transfer = host_forward_transfer
self.fore_gradient_transfer = fore_gradient_transfer
self.unilateral_gradient_transfer = host_gradient_transfer
self.unilateral_optim_gradient_transfer = host_optim_gradient_transfer
def compute_forwards(self, data_instances, model_weights):
raise NotImplementedError("Function should not be called here")
def compute_unilateral_gradient(self, data_instances, fore_gradient, model_weights, optimizer):
raise NotImplementedError("Function should not be called here")
def compute_gradient_procedure(self, data_instances, encrypted_calculator, model_weights,
optimizer,
n_iter_, batch_index):
"""
Linear model gradient procedure
Step 1: Compute host forwards, which differ between algorithms
For Logistic Regression and Linear Regression: forwards = wx
"""
current_suffix = (n_iter_, batch_index)
self.forwards = self.compute_forwards(data_instances, model_weights)
LOGGER.debug("Host computed: u^H={}, type={}".format(list(self.forwards.collect()), type(self.forwards)))
encrypted_forward = encrypted_calculator[batch_index].encrypt(self.forwards)
LOGGER.debug("Host encrypted: [[u^H]]={}, type={}".format(list(encrypted_forward.collect()), type(encrypted_forward)))
self.remote_host_forward(encrypted_forward, suffix=current_suffix)
fore_gradient = self.get_fore_gradient(suffix=current_suffix)
unilateral_gradient = compute_gradient(data_instances,
fore_gradient,
model_weights.fit_intercept)
LOGGER.debug("Host computed: [[d]]x^H={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
if optimizer is not None:
unilateral_gradient = optimizer.add_regular_to_grad(unilateral_gradient, model_weights)
LOGGER.debug("Host computed: [[d]]x^H+[[λΘ^H]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
optimized_gradient = self.update_gradient(unilateral_gradient, suffix=current_suffix)
return optimized_gradient, fore_gradient
def compute_sqn_forwards(self, data_instances, delta_s, cipher_operator):
"""
To compute Hessian matrix, y, s are needed.
g = (1/N)*∑(0.25 * wx - 0.5 * y) * x
y = ∇2^F(w_t)s_t = g' * s = (1/N)*∑(0.25 * x * s) * x
define forward_hess = ∑(0.25 * x * s)
"""
sqn_forwards = data_instances.mapValues(
lambda v: cipher_operator.encrypt(fate_operator.vec_dot(v.features, delta_s.coef_) + delta_s.intercept_))
# forward_sum = sqn_forwards.reduce(reduce_add)
return sqn_forwards
def compute_forward_hess(self, data_instances, delta_s, forward_hess):
"""
To compute Hessian matrix, y, s are needed.
g = (1/N)*∑(0.25 * wx - 0.5 * y) * x
y = ∇2^F(w_t)s_t = g' * s = (1/N)*∑(0.25 * x * s) * x
define forward_hess = (0.25 * x * s)
"""
hess_vector = compute_gradient(data_instances,
forward_hess,
delta_s.fit_intercept)
return np.array(hess_vector)
def remote_host_forward(self, host_forward, suffix=tuple()):
self.host_forward_transfer.remote(obj=host_forward, role=consts.GUEST, idx=0, suffix=suffix)
LOGGER.debug("Host sent to Guest: [[u^H]]={}, type={}".format(host_forward, type(host_forward)))
def get_fore_gradient(self, suffix=tuple()):
host_forward = self.fore_gradient_transfer.get(idx=0, suffix=suffix)
LOGGER.debug("Host received from Guest: fore_gradient={}, type={}".format(host_forward, type(host_forward)))
return host_forward
def update_gradient(self, unilateral_gradient, suffix=tuple()):
self.unilateral_gradient_transfer.remote(unilateral_gradient, role=consts.ARBITER, idx=0, suffix=suffix)
LOGGER.debug("Host sent to Arbiter: [[g^H]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
optimized_gradient = self.unilateral_optim_gradient_transfer.get(idx=0, suffix=suffix)
LOGGER.debug("Host received from Arbiter: g^H={}, type={}".format(optimized_gradient, type(optimized_gradient)))
return optimized_gradient
class Arbiter(HeteroGradientBase):
def __init__(self):
self.has_multiple_hosts = False
def _register_gradient_sync(self, guest_gradient_transfer, host_gradient_transfer,
guest_optim_gradient_transfer, host_optim_gradient_transfer):
self.guest_gradient_transfer = guest_gradient_transfer
self.host_gradient_transfer = host_gradient_transfer
self.guest_optim_gradient_transfer = guest_optim_gradient_transfer
self.host_optim_gradient_transfer = host_optim_gradient_transfer
def compute_gradient_procedure(self, cipher_operator, optimizer, n_iter_, batch_index):
"""
Compute gradients.
Receive local gradients from the guest and hosts, merge and optimize them, then separate the result and send each party its share back.
Parameters
----------
cipher_operator: Use for encryption
optimizer: optimizer that get delta gradient of this iter
n_iter_: int, current iter nums
batch_index: int, use to obtain current encrypted_calculator
"""
current_suffix = (n_iter_, batch_index)
host_gradients, guest_gradient = self.get_local_gradient(current_suffix)
if len(host_gradients) > 1:
self.has_multiple_hosts = True
host_gradients = [np.array(h) for h in host_gradients]
guest_gradient = np.array(guest_gradient)
size_list = [h_g.shape[0] for h_g in host_gradients]
size_list.append(guest_gradient.shape[0])
gradient = np.hstack(host_gradients)
gradient = np.hstack((gradient, guest_gradient))
grad = np.array(cipher_operator.decrypt_list(gradient))
# LOGGER.debug("In arbiter compute_gradient_procedure, before apply grad: {}, size_list: {}".format(
# grad, size_list
# ))
delta_grad = optimizer.apply_gradients(grad)
# LOGGER.debug("In arbiter compute_gradient_procedure, delta_grad: {}".format(
# delta_grad
# ))
separate_optim_gradient = self.separate(delta_grad, size_list)
# LOGGER.debug("In arbiter compute_gradient_procedure, separated gradient: {}".format(
# separate_optim_gradient
# ))
host_optim_gradients = separate_optim_gradient[: -1]
guest_optim_gradient = separate_optim_gradient[-1]
self.remote_local_gradient(host_optim_gradients, guest_optim_gradient, current_suffix)
return delta_grad
@staticmethod
def separate(value, size_list):
"""
Split value into several sets according to size_list
Parameters
----------
value: list or ndarray, input data
size_list: list, each set size
Returns
----------
list
the sets after splitting
"""
separate_res = []
cur = 0
for size in size_list:
separate_res.append(value[cur:cur + size])
cur += size
return separate_res
def get_local_gradient(self, suffix=tuple()):
host_gradients = self.host_gradient_transfer.get(idx=-1, suffix=suffix)
LOGGER.info("Get host_gradient from Host")
guest_gradient = self.guest_gradient_transfer.get(idx=0, suffix=suffix)
LOGGER.info("Get guest_gradient from Guest")
return host_gradients, guest_gradient
def remote_local_gradient(self, host_optim_gradients, guest_optim_gradient, suffix=tuple()):
for idx, host_optim_gradient in enumerate(host_optim_gradients):
self.host_optim_gradient_transfer.remote(host_optim_gradient,
role=consts.HOST,
idx=idx,
suffix=suffix)
self.guest_optim_gradient_transfer.remote(guest_optim_gradient,
role=consts.GUEST,
idx=0,
suffix=suffix)
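The Arbiter's merge / decrypt / optimize / split bookkeeping can likewise be shown standalone. The sketch below is plain numpy with the decryption step omitted and a bare gradient-descent step standing in for optimizer.apply_gradients(); all numbers are invented.
import numpy as np

host_gradients = [np.array([0.10, -0.05]), np.array([0.20])]   # one decrypted g^H per Host party
guest_gradient = np.array([0.30, 0.15, -0.10])                  # decrypted g^G

size_list = [g.shape[0] for g in host_gradients] + [guest_gradient.shape[0]]
merged = np.hstack(host_gradients + [guest_gradient])

learning_rate = 0.1
delta_grad = learning_rate * merged                              # stand-in for optimizer.apply_gradients(grad)

# same splitting as Arbiter.separate(): cut the flat vector back into per-party pieces
parts = np.split(delta_grad, np.cumsum(size_list)[:-1])
host_optim_gradients, guest_optim_gradient = parts[:-1], parts[-1]
print(host_optim_gradients, guest_optim_gradient)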
./federatedml/optim/gradient/hetero_linr_gradient_and_loss.py
Guest->compute_and_aggregate_forwards()
L65:
LOGGER.debug("Guest computed forwards: u^G={}, type={}".format(self.forwards, type(self.forwards)))
L67:
LOGGER.debug("Guest encrypted forwards: [[u^G]]={}, type={}".format(self.aggregated_forwards, type(self.aggregated_forwards)))
L71:
LOGGER.debug("Guest computed: [[u^G+u^H]]={}, type={}".format(self.aggregated_forwards, type(self.aggregated_forwards)))
L73:
LOGGER.debug("Guest computed: [[d]]=[[u^G+u^H-y]]={}, type={}".format(fore_gradient, type(fore_gradient)))
Guest->compute_loss()
L98:
LOGGER.debug("Guest computed: (u^G-y)={}, type={}".format(wxy, type(wxy)))
L100:
LOGGER.debug("Guest computed: (u^G-y)^2={}, type={}".format(wxy_square, type(wxy_square)))
L103:
LOGGER.debug("Guest computed: [[u^H(u^G-y)]]={}, type={}".format(loss_gh, type(loss_gh)))
L105:
LOGGER.debug("Guest computed: [[(u^H+u^G-y)^2]]={}, type={}".format(loss, type(loss)))
L108:
LOGGER.debug("Guest computed: [[L]]={}, type={}".format(loss, type(loss)))
Host->compute_loss()
L155:
LOGGER.debug("Host computed: (u^H)^2={}, type={}".format(self_wx_square, type(self_wx_square)))
L157:
LOGGER.debug("Host encrypted: [[(u^H)^2]]={}, type={}".format(en_wx_square, type(en_wx_square)))
L162:
LOGGER.debug("Host computed: (λ/2)(Θ^H)^2={}, type={}".format(loss_regular, type(loss_regular)))
L164:
LOGGER.debug("Host encrypted: [[(λ/2)(Θ^H)^2]]={}, type={}".format(en_loss_regular, type(en_loss_regular)))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from federatedml.framework.hetero.sync import loss_sync
from federatedml.optim.gradient import hetero_linear_model_gradient
from federatedml.util import LOGGER
from federatedml.util.fate_operator import reduce_add, vec_dot
class Guest(hetero_linear_model_gradient.Guest, loss_sync.Guest):
def register_gradient_procedure(self, transfer_variables):
self._register_gradient_sync(transfer_variables.host_forward,
transfer_variables.fore_gradient,
transfer_variables.guest_gradient,
transfer_variables.guest_optim_gradient)
self._register_loss_sync(transfer_variables.host_loss_regular,
transfer_variables.loss,
transfer_variables.loss_intermediate)
def compute_and_aggregate_forwards(self, data_instances, model_weights,
encrypted_calculator, batch_index, current_suffix, offset=None):
"""
Compute gradients:
gradient = (1/N)*\sum(wx -y)*x
Define wx as guest_forward or host_forward
Define (wx-y) as fore_gradient
Parameters
----------
data_instances: DTable of Instance, input data
model_weights: LinearRegressionWeights
Stores coef_ and intercept_ of model
encrypted_calculator: Use for different encrypted methods
offset: Used in Poisson only.
batch_index: int, use to obtain current encrypted_calculator index:
current_suffix: tuple or string. Used in transfer_variable
"""
wx = data_instances.mapValues(
lambda v: vec_dot(v.features, model_weights.coef_) + model_weights.intercept_)
self.forwards = wx
LOGGER.debug("Guest computed forwards: u^G={}, type={}".format(self.forwards, type(self.forwards)))
self.aggregated_forwards = encrypted_calculator[batch_index].encrypt(wx)
LOGGER.debug("Guest encrypted forwards: [[u^G]]={}, type={}".format(self.aggregated_forwards, type(self.aggregated_forwards)))
self.host_forwards = self.get_host_forward(suffix=current_suffix)
for host_forward in self.host_forwards:
self.aggregated_forwards = self.aggregated_forwards.join(host_forward, lambda g, h: g + h)
LOGGER.debug("Guest computed: [[u^G+u^H]]={}, type={}".format(self.aggregated_forwards, type(self.aggregated_forwards)))
fore_gradient = self.aggregated_forwards.join(data_instances, lambda wx, d: wx - d.label)
LOGGER.debug("Guest computed: [[d]]=[[u^G+u^H-y]]={}, type={}".format(fore_gradient, type(fore_gradient)))
return fore_gradient
def compute_loss(self, data_instances, n_iter_, batch_index, loss_norm=None):
'''
Compute hetero linr loss:
loss = (1/2N)*\sum(wx-y)^2 where y is label, w is model weight and x is features
(wx - y)^2 = (wx_h)^2 + (wx_g - y)^2 + 2*wx_h*(wx_g - y)
'''
current_suffix = (n_iter_, batch_index)
n = data_instances.count()
loss_list = []
host_wx_squares = self.get_host_loss_intermediate(current_suffix)
if loss_norm is not None:
host_loss_regular = self.get_host_loss_regular(suffix=current_suffix)
else:
host_loss_regular = []
if len(self.host_forwards) > 1:
LOGGER.info("More than one host exist, loss is not available")
else:
host_forward = self.host_forwards[0]
host_wx_square = host_wx_squares[0]
wxy = self.forwards.join(data_instances, lambda wx, d: wx - d.label)
LOGGER.debug("Guest computed: (u^G-y)={}, type={}".format(wxy, type(wxy)))
wxy_square = wxy.mapValues(lambda x: np.square(x)).reduce(reduce_add)
LOGGER.debug("Guest computed: (u^G-y)^2={}, type={}".format(wxy_square, type(wxy_square)))
loss_gh = wxy.join(host_forward, lambda g, h: g * h).reduce(reduce_add)
LOGGER.debug("Guest computed: [[u^H(u^G-y)]]={}, type={}".format(loss_gh, type(loss_gh)))
loss = (wxy_square + host_wx_square + 2 * loss_gh) / (2 * n)
LOGGER.debug("Guest computed: [[(u^H+u^G-y)^2]]={}, type={}".format(loss, type(loss)))
if loss_norm is not None:
loss = loss + loss_norm + host_loss_regular[0]
LOGGER.debug("Guest computed: [[L]]={}, type={}".format(loss, type(loss)))
loss_list.append(loss)
LOGGER.debug("In compute_loss, loss list are: {}".format(loss_list))
self.sync_loss_info(loss_list, suffix=current_suffix)
def compute_forward_hess(self, data_instances, delta_s, host_forwards):
"""
To compute Hessian matrix, y, s are needed.
g = (1/N)*∑(wx - y) * x
y = ∇2^F(w_t)s_t = g' * s = (1/N)*∑(x * s) * x
define forward_hess = (1/N)*∑(x * s)
"""
forwards = data_instances.mapValues(
lambda v: (np.dot(v.features, delta_s.coef_) + delta_s.intercept_))
for host_forward in host_forwards:
forwards = forwards.join(host_forward, lambda g, h: g + h)
hess_vector = hetero_linear_model_gradient.compute_gradient(data_instances,
forwards,
delta_s.fit_intercept)
return forwards, np.array(hess_vector)
class Host(hetero_linear_model_gradient.Host, loss_sync.Host):
def register_gradient_procedure(self, transfer_variables):
self._register_gradient_sync(transfer_variables.host_forward,
transfer_variables.fore_gradient,
transfer_variables.host_gradient,
transfer_variables.host_optim_gradient)
self._register_loss_sync(transfer_variables.host_loss_regular,
transfer_variables.loss,
transfer_variables.loss_intermediate)
def compute_forwards(self, data_instances, model_weights):
wx = data_instances.mapValues(lambda v: vec_dot(v.features, model_weights.coef_) + model_weights.intercept_)
return wx
def compute_loss(self, model_weights, optimizer, n_iter_, batch_index, cipher_operator):
'''
Compute hetero linr loss for:
loss = (1/2N)*\sum(wx-y)^2 where y is label, w is model weight and x is features
Note: (wx - y)^2 = (wx_h)^2 + (wx_g - y)^2 + 2*wx_h*(wx_g - y)
'''
current_suffix = (n_iter_, batch_index)
self_wx_square = self.forwards.mapValues(lambda x: np.square(x)).reduce(reduce_add)
LOGGER.debug("Host computed: (u^H)^2={}, type={}".format(self_wx_square, type(self_wx_square)))
en_wx_square = cipher_operator.encrypt(self_wx_square)
LOGGER.debug("Host encrypted: [[(u^H)^2]]={}, type={}".format(en_wx_square, type(en_wx_square)))
self.remote_loss_intermediate(en_wx_square, suffix=current_suffix)
loss_regular = optimizer.loss_norm(model_weights)
if loss_regular is not None:
LOGGER.debug("Host computed: (λ/2)(Θ^H)^2={}, type={}".format(loss_regular, type(loss_regular)))
en_loss_regular = cipher_operator.encrypt(loss_regular)
LOGGER.debug("Host encrypted: [[(λ/2)(Θ^H)^2]]={}, type={}".format(en_loss_regular, type(en_loss_regular)))
self.remote_loss_regular(en_loss_regular, suffix=current_suffix)
class Arbiter(hetero_linear_model_gradient.Arbiter, loss_sync.Arbiter):
def register_gradient_procedure(self, transfer_variables):
self._register_gradient_sync(transfer_variables.guest_gradient,
transfer_variables.host_gradient,
transfer_variables.guest_optim_gradient,
transfer_variables.host_optim_gradient)
self._register_loss_sync(transfer_variables.loss)
def compute_loss(self, cipher, n_iter_, batch_index):
"""
Decrypt loss from guest
"""
current_suffix = (n_iter_, batch_index)
loss_list = self.sync_loss_info(suffix=current_suffix)
de_loss_list = cipher.decrypt_list(loss_list)
return de_loss_list
3. Data prediction: Making predictions with the trained hetero (vertical) linear regression model is straightforward, and the intermediate parameters are not protected by encryption. The flow and the messages exchanged are shown in the figure below.
Source code:
Logs added to the source code:
./federatedml/linear_model/linear_regression/hetero_linear_regression/hetero_linr_guest.py
HeteroLinRGuest->fit()
L67:
LOGGER.debug("Guest received from Arbiter: PK={}".format(self.cipher_operator.get_public_key()))
HeteroLinRGuest->predict()
L148:
LOGGER.debug("Guest computed: u^G={}, type={}".format(pred, type(pred)))
L150:
LOGGER.debug("Guest received from Host: u^H={}, type={}".format(host_preds, type(host_preds)))
L157:
LOGGER.debug("Guest computed: y={}, type={}".format(predict_result, type(predict_result)))
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from federatedml.framework.hetero.procedure import convergence
from federatedml.framework.hetero.procedure import paillier_cipher, batch_generator
from federatedml.linear_model.linear_model_weight import LinearModelWeights
from federatedml.linear_model.linear_regression.hetero_linear_regression.hetero_linr_base import HeteroLinRBase
from federatedml.optim.gradient import hetero_linr_gradient_and_loss
from federatedml.secureprotol import EncryptModeCalculator
from federatedml.util import LOGGER
from federatedml.util import consts
from federatedml.util.io_check import assert_io_num_rows_equal
class HeteroLinRGuest(HeteroLinRBase):
def __init__(self):
super().__init__()
self.data_batch_count = []
# self.guest_forward = None
self.role = consts.GUEST
self.cipher = paillier_cipher.Guest()
self.batch_generator = batch_generator.Guest()
self.gradient_loss_operator = hetero_linr_gradient_and_loss.Guest()
self.converge_procedure = convergence.Guest()
self.encrypted_calculator = None
@staticmethod
def load_data(data_instance):
"""
return data_instance as original
Parameters
----------
data_instance: DTable of Instance, input data
"""
return data_instance
def fit(self, data_instances, validate_data=None):
"""
Train linR model of role guest
Parameters
----------
data_instances: DTable of Instance, input data
"""
LOGGER.info("Enter hetero_linR_guest fit")
self._abnormal_detection(data_instances)
self.header = self.get_header(data_instances)
self.validation_strategy = self.init_validation_strategy(data_instances, validate_data)
self.cipher_operator = self.cipher.gen_paillier_cipher_operator()
LOGGER.info("Generate mini-batch from input data")
LOGGER.debug("Guest received from Arbiter: PK={}".format(self.cipher_operator.get_public_key()))
self.batch_generator.initialize_batch_generator(data_instances, self.batch_size)
self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_nums)
self.encrypted_calculator = [EncryptModeCalculator(self.cipher_operator,
self.encrypted_mode_calculator_param.mode,
self.encrypted_mode_calculator_param.re_encrypted_rate) for _
in range(self.batch_generator.batch_nums)]
LOGGER.info("Start initialize model.")
LOGGER.info("fit_intercept:{}".format(self.init_param_obj.fit_intercept))
model_shape = self.get_features_shape(data_instances)
w = self.initializer.init_model(model_shape, init_params=self.init_param_obj)
self.model_weights = LinearModelWeights(w, fit_intercept=self.fit_intercept)
while self.n_iter_ < self.max_iter:
LOGGER.info("iter:{}".format(self.n_iter_))
# each iter will get the same batch_data_generator
batch_data_generator = self.batch_generator.generate_batch_data()
self.optimizer.set_iters(self.n_iter_)
batch_index = 0
for batch_data in batch_data_generator:
# transforms features of raw input 'batch_data_inst' into more representative features 'batch_feat_inst'
batch_feat_inst = self.transform(batch_data)
# Start gradient procedure
optim_guest_gradient, _, _ = self.gradient_loss_operator.compute_gradient_procedure(
batch_feat_inst,
self.encrypted_calculator,
self.model_weights,
self.optimizer,
self.n_iter_,
batch_index
)
loss_norm = self.optimizer.loss_norm(self.model_weights)
self.gradient_loss_operator.compute_loss(data_instances, self.n_iter_, batch_index, loss_norm)
self.model_weights = self.optimizer.update_model(self.model_weights, optim_guest_gradient)
batch_index += 1
# LOGGER.debug(
# "model_weights, iters: {}, update_model: {}".format(self.n_iter_, self.model_weights.unboxed))
self.is_converged = self.converge_procedure.sync_converge_info(suffix=(self.n_iter_,))
LOGGER.info("iter: {}, is_converged: {}".format(self.n_iter_, self.is_converged))
# LOGGER.debug("model weights is {}".format(self.model_weights.coef_))
if self.validation_strategy:
LOGGER.debug('LinR guest running validation')
self.validation_strategy.validate(self, self.n_iter_)
if self.validation_strategy.need_stop():
LOGGER.debug('early stopping triggered')
break
self.n_iter_ += 1
if self.is_converged:
break
if self.validation_strategy and self.validation_strategy.has_saved_best_model():
self.load_model(self.validation_strategy.cur_best_model)
self.set_summary(self.get_model_summary())
@assert_io_num_rows_equal
def predict(self, data_instances):
"""
Prediction of linR
Parameters
----------
data_instances: DTable of Instance, input data
predict_param: PredictParam, the setting of prediction.
Returns
----------
DTable
include input data label, predict results
"""
LOGGER.info("Start predict ...")
self._abnormal_detection(data_instances)
data_instances = self.align_data_header(data_instances, self.header)
data_features = self.transform(data_instances)
pred = self.compute_wx(data_features, self.model_weights.coef_, self.model_weights.intercept_)
LOGGER.debug("Guest computed: u^G={}, type={}".format(pred, type(pred)))
host_preds = self.transfer_variable.host_partial_prediction.get(idx=-1)
LOGGER.debug("Guest received from Host: u^H={}, type={}".format(host_preds, type(host_preds)))
LOGGER.info("Get prediction from Host")
for host_pred in host_preds:
pred = pred.join(host_pred, lambda g, h: g + h)
# predict_result = data_instances.join(pred, lambda d, pred: [d.label, pred, pred, {"label": pred}])
predict_result = self.predict_score_to_output(data_instances=data_instances, predict_score=pred,
classes=None)
LOGGER.debug("Guest computed: y={}, type={}".format(predict_result, type(predict_result)))
return predict_result
./federatedml/linear_model/linear_regression/hetero_linear_regression/hetero_linr_host.py
HeteroLinRHost->fit()
L58:
LOGGER.debug("Host receives from Arbiter: PK={}".format(self.cipher_operator.get_public_key()))
HeteroLinRHost->predict()
L131:
LOGGER.debug("Guest computed: u^H={}, type={}".format(pred_host, type(pred_host)))
L133:
LOGGER.debug("Host sent to Guest: u^H={}, type={}".format(pred_host, type(pred_host)))
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from federatedml.framework.hetero.procedure import convergence
from federatedml.framework.hetero.procedure import paillier_cipher, batch_generator
from federatedml.linear_model.linear_model_weight import LinearModelWeights
from federatedml.linear_model.linear_regression.hetero_linear_regression.hetero_linr_base import HeteroLinRBase
from federatedml.optim.gradient import hetero_linr_gradient_and_loss
from federatedml.secureprotol import EncryptModeCalculator
from federatedml.util import LOGGER
from federatedml.util import consts
class HeteroLinRHost(HeteroLinRBase):
def __init__(self):
super(HeteroLinRHost, self).__init__()
self.batch_num = None
self.batch_index_list = []
self.role = consts.HOST
self.cipher = paillier_cipher.Host()
self.batch_generator = batch_generator.Host()
self.gradient_loss_operator = hetero_linr_gradient_and_loss.Host()
self.converge_procedure = convergence.Host()
self.encrypted_calculator = None
def fit(self, data_instances, validate_data=None):
"""
Train linear regression model of role host
Parameters
----------
data_instances: DTable of Instance, input data
"""
LOGGER.info("Enter hetero_linR host")
self._abnormal_detection(data_instances)
self.validation_strategy = self.init_validation_strategy(data_instances, validate_data)
self.header = self.get_header(data_instances)
self.cipher_operator = self.cipher.gen_paillier_cipher_operator()
self.batch_generator.initialize_batch_generator(data_instances)
self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_nums)
LOGGER.debug("Host receives from Arbiter: PK={}".format(self.cipher_operator.get_public_key()))
self.encrypted_calculator = [EncryptModeCalculator(self.cipher_operator,
self.encrypted_mode_calculator_param.mode,
self.encrypted_mode_calculator_param.re_encrypted_rate) for _
in range(self.batch_generator.batch_nums)]
LOGGER.info("Start initialize model.")
model_shape = self.get_features_shape(data_instances)
if self.init_param_obj.fit_intercept:
self.init_param_obj.fit_intercept = False
w = self.initializer.init_model(model_shape, init_params=self.init_param_obj)
self.model_weights = LinearModelWeights(w, fit_intercept=self.fit_intercept)
while self.n_iter_ < self.max_iter:
LOGGER.info("iter:" + str(self.n_iter_))
self.optimizer.set_iters(self.n_iter_)
batch_data_generator = self.batch_generator.generate_batch_data()
batch_index = 0
for batch_data in batch_data_generator:
batch_feat_inst = self.transform(batch_data)
optim_host_gradient, _ = self.gradient_loss_operator.compute_gradient_procedure(
batch_feat_inst,
self.encrypted_calculator,
self.model_weights,
self.optimizer,
self.n_iter_,
batch_index)
self.gradient_loss_operator.compute_loss(self.model_weights, self.optimizer, self.n_iter_, batch_index,
self.cipher_operator)
self.model_weights = self.optimizer.update_model(self.model_weights, optim_host_gradient)
batch_index += 1
self.is_converged = self.converge_procedure.sync_converge_info(suffix=(self.n_iter_,))
LOGGER.info("Get is_converged flag from arbiter:{}".format(self.is_converged))
if self.validation_strategy:
LOGGER.debug('LinR host running validation')
self.validation_strategy.validate(self, self.n_iter_)
if self.validation_strategy.need_stop():
LOGGER.debug('early stopping triggered')
break
self.n_iter_ += 1
LOGGER.info("iter: {}, is_converged: {}".format(self.n_iter_, self.is_converged))
if self.is_converged:
break
if not self.is_converged:
LOGGER.info("Reach max iter {}, train model finish!".format(self.max_iter))
if self.validation_strategy and self.validation_strategy.has_saved_best_model():
self.load_model(self.validation_strategy.cur_best_model)
self.set_summary(self.get_model_summary())
# LOGGER.debug(f"summary content is: {self.summary()}")
def predict(self, data_instances):
"""
Prediction of linR
Parameters
----------
data_instances:DTable of Instance, input data
"""
self.transfer_variable.host_partial_prediction.disable_auto_clean()
LOGGER.info("Start predict ...")
self._abnormal_detection(data_instances)
data_instances = self.align_data_header(data_instances, self.header)
data_features = self.transform(data_instances)
pred_host = self.compute_wx(data_features, self.model_weights.coef_, self.model_weights.intercept_)
LOGGER.debug("Guest computed: u^H={}, type={}".format(pred_host, type(pred_host)))
self.transfer_variable.host_partial_prediction.remote(pred_host, role=consts.GUEST, idx=0)
LOGGER.debug("Host sent to Guest: u^H={}, type={}".format(pred_host, type(pred_host)))
LOGGER.info("Remote partial prediction to Guest")