searchusermenu
  • 发布文章
  • 消息中心
点赞
收藏
评论
分享
原创

多方隐私计算纵向线性回归算法详细过程分析

2023-03-28 10:24:43
30
0

多方隐私计算纵向线性回归算法流程描述:核心流程包括样本对齐、模型训练、数据预测三个阶段。

1. 样本对齐:样本对齐使用基于RSA算法的安全求交,其流程及交互消息如下图所示。

   

源码:

Host侧:源码中添加日志

RsaIntersectionHost->run()
 L175:
  LOGGER.debug("Host generated RSA key pair: PK_e={}, PK_n={}, SK_d={}".format(self.e, self.n, self.d))
 L181:
  LOGGER.debug("Host sent to Guest: PK={}".format(public_key))
 L184:
  LOGGER.debug("Host computed: Z_H={}, type={}".format(host_ids_process_pair, type(host_ids_process_pair)))
 L191:
  LOGGER.debug("Host sent to Guest: Z_H={}, type={}".format(host_ids_process, type(host_ids_process)))
 L196:
  LOGGER.debug("Host received from Guest: Y_G={}, type={}".format(guest_ids, type(guest_ids)))
 L201:
  LOGGER.debug("Host computed: Z_G={}, type={}".format(guest_ids_process, type(guest_ids_process)))
 L205:
  LOGGER.debug("Host sent to Guest: Z_G={}, type={}".format(guest_ids_process, type(guest_ids_process)))
 L212:
  LOGGER.debug("Host received from Guest: S={}, type={}".format(encrypt_intersect_ids, type(encrypt_intersect_ids)))
 L215:
  LOGGER.debug("Host computed: I_id={}, type={}".format(intersect_ids, type(intersect_ids)))

Host侧:添加后的源码

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

import hashlib

from fate_arch.session import computing_session as session
from federatedml.secureprotol import gmpy_math
from federatedml.secureprotol.encrypt import RsaEncrypt
#from federatedml.statistic.intersect.rsa_cache import cache_utils
from federatedml.statistic.intersect import RawIntersect
from federatedml.statistic.intersect import RsaIntersect
from federatedml.util import consts
from federatedml.util import LOGGER
from federatedml.transfer_variable.transfer_class.rsa_intersect_transfer_variable import RsaIntersectTransferVariable


class RsaIntersectionHost(RsaIntersect):
    """Host side of the RSA-based secure ID intersection.

    The host owns the RSA key pair ``(e, d, n)``. In :meth:`run` it publishes
    ``(e, n)`` to the guest, sends ``hash(hash(id)^d mod n)`` of its own ids
    (Z_H), signs the guest's blinded ids with ``d`` (Z_G) and, when
    ``sync_intersect_ids`` is set, receives the encrypted intersection back
    and maps it to raw ids.

    NOTE(review): the cache code paths reference ``cache_utils``, whose import
    is commented out at the top of this file; with
    ``intersect_cache_param.use_cache`` enabled they raise ``NameError`` until
    that import is restored.
    NOTE(review): ``intersect_cache_param``, ``sync_intersect_ids``,
    ``only_output_key``, ``host_party_id``, ``hash`` and
    ``_get_value_from_data`` are not defined in this class; presumably they
    come from ``RsaIntersect`` -- confirm.
    """

    def __init__(self, intersect_params):
        super().__init__(intersect_params)
        self.transfer_variable = RsaIntersectTransferVariable()

        # RSA key material, filled in by run() through get_rsa_key().
        self.e = None
        self.d = None
        self.n = None

        # parameter for intersection cache
        self.is_version_match = False
        self.has_cache_version = True

    def cal_host_ids_process_pair(self, data_instances):
        """Map every id ``k`` to ``(hash(hash(k)^d mod n), k)``.

        :param data_instances: table keyed by raw id
        :return: table of (processed id, raw id)
        """
        return data_instances.map(
            lambda k, v: (
                RsaIntersectionHost.hash(gmpy_math.powmod(int(RsaIntersectionHost.hash(k), 16), self.d, self.n)), k)
        )

    def generate_rsa_key(self, rsa_bit=1024):
        """Generate a fresh RSA key pair and return what ``get_key_pair`` yields.

        :param rsa_bit: key length in bits handed to ``RsaEncrypt.generate_key``
        """
        encrypt_operator = RsaEncrypt()
        encrypt_operator.generate_key(rsa_bit)
        return encrypt_operator.get_key_pair()

    def get_rsa_key(self):
        """Return ``(e, d, n)``, read from the cache when enabled and present.

        When the cache is enabled but holds no key, ``has_cache_version`` is
        set to False and a new key is generated.
        """
        if self.intersect_cache_param.use_cache:
            LOGGER.info("Using intersection cache scheme, start to getting rsa key from cache.")
            rsa_key = cache_utils.get_rsa_of_current_version(host_party_id=self.host_party_id,
                                                             id_type=self.intersect_cache_param.id_type,
                                                             encrypt_type=self.intersect_cache_param.encrypt_type,
                                                             tag='Za')
            if rsa_key is not None:
                e = int(rsa_key.get('rsa_e'))
                d = int(rsa_key.get('rsa_d'))
                n = int(rsa_key.get('rsa_n'))
            else:
                self.has_cache_version = False
                LOGGER.info("Use cache but can not find any version in cache, set has_cache_version to false")
                LOGGER.info("Start to generate rsa key")
                e, d, n = self.generate_rsa_key()
        else:
            LOGGER.info("Not use cache, generate rsa keys.")
            e, d, n = self.generate_rsa_key()
        return e, d, n

    def store_cache(self, host_id, rsa_key: dict, assign_version=None, assign_namespace=None):
        """Store the processed host ids and the RSA key in the cache.

        :param host_id: table of processed host ids to store
        :param rsa_key: dict with keys 'rsa_e', 'rsa_d', 'rsa_n'
        :param assign_version: version passed through to ``cache_utils.store_cache``
        :param assign_namespace: namespace passed through to ``cache_utils.store_cache``
        :return: (version, namespace) as reported back by ``cache_utils.store_cache``
        """
        store_cache_ret = cache_utils.store_cache(dtable=host_id,
                                                  guest_party_id=None,
                                                  host_party_id=self.host_party_id,
                                                  version=assign_version,
                                                  id_type=self.intersect_cache_param.id_type,
                                                  encrypt_type=self.intersect_cache_param.encrypt_type,
                                                  tag='Za',
                                                  namespace=assign_namespace)
        LOGGER.info("Finish store host_ids_process to cache")

        version = store_cache_ret.get('table_name')
        namespace = store_cache_ret.get('namespace')
        cache_utils.store_rsa(host_party_id=self.host_party_id,
                              id_type=self.intersect_cache_param.id_type,
                              encrypt_type=self.intersect_cache_param.encrypt_type,
                              tag='Za',
                              namespace=namespace,
                              version=version,
                              rsa=rsa_key
                              )
        LOGGER.info("Finish store rsa key to cache")

        return version, namespace

    def _rsa_key_dict(self):
        """Return the current key in the dict layout ``store_cache`` expects."""
        return {'rsa_e': self.e, 'rsa_d': self.d, 'rsa_n': self.n}

    def _host_ids_from_cached_version(self, data_instances, current_version):
        """Negotiate the cached version with the guest and load Za if needed.

        :param current_version: non-None dict with 'table_name' and 'namespace'
        :return: table of (processed id, raw id), or None when the versions
            match and the intersection will not be synced back
        """
        version = current_version.get('table_name')
        namespace = current_version.get('namespace')
        guest_current_version = self.transfer_variable.cache_version_info.get(0)
        LOGGER.info("current_version:{}".format(current_version))
        LOGGER.info("guest_current_version:{}".format(guest_current_version))

        self.is_version_match = (guest_current_version.get('table_name') == version
                                 and guest_current_version.get('namespace') == namespace)

        version_match_info = {'version_match': self.is_version_match,
                              'version': version,
                              'namespace': namespace}
        self.transfer_variable.cache_version_match_info.remote(version_match_info,
                                                               role=consts.GUEST,
                                                               idx=0)

        host_ids_process_pair = None
        if not self.is_version_match or self.sync_intersect_ids:
            # if self.sync_intersect_ids is true, host will get the encrypted intersect id from guest,
            # which need the Za to decrypt them
            LOGGER.info("read Za from cache")
            host_ids_process_pair = session.table(name=version,
                                                  namespace=namespace,
                                                  create_if_missing=True,
                                                  error_if_exist=False)
            if host_ids_process_pair.count() == 0:
                # Cache table is empty: recompute Za and store it again.
                host_ids_process_pair = self.cal_host_ids_process_pair(data_instances)
                self.store_cache(host_ids_process_pair, rsa_key=self._rsa_key_dict())
        return host_ids_process_pair

    def _host_ids_with_new_version(self, data_instances):
        """Create a new cache version, announce it to the guest and fill it."""
        self.is_version_match = False
        LOGGER.info("is version_match:{}".format(self.is_version_match))
        namespace = cache_utils.gen_cache_namespace(id_type=self.intersect_cache_param.id_type,
                                                    encrypt_type=self.intersect_cache_param.encrypt_type,
                                                    tag='Za',
                                                    host_party_id=self.host_party_id)
        version = cache_utils.gen_cache_version(namespace=namespace,
                                                create=True)
        version_match_info = {'version_match': self.is_version_match,
                              'version': version,
                              'namespace': namespace}
        self.transfer_variable.cache_version_match_info.remote(version_match_info,
                                                               role=consts.GUEST,
                                                               idx=0)

        host_ids_process_pair = self.cal_host_ids_process_pair(data_instances)
        self.store_cache(host_ids_process_pair, rsa_key=self._rsa_key_dict(),
                         assign_version=version, assign_namespace=namespace)
        return host_ids_process_pair

    def host_ids_process(self, data_instances):
        """Produce the host's processed ids (Za), going through the cache if enabled.

        :param data_instances: table keyed by raw id
        :return: table of (processed id, raw id); may be None when the cache
            version matches the guest's and ``sync_intersect_ids`` is false
        """
        if not self.intersect_cache_param.use_cache:
            LOGGER.info("Not using cache, calculate Za using raw id")
            return self.cal_host_ids_process_pair(data_instances)

        LOGGER.info("Use intersect cache.")
        current_version = None
        if self.has_cache_version:
            current_version = cache_utils.host_get_current_verison(host_party_id=self.host_party_id,
                                                                   id_type=self.intersect_cache_param.id_type,
                                                                   encrypt_type=self.intersect_cache_param.encrypt_type,
                                                                   tag='Za')
            if current_version is None:
                # Previously the record was dereferenced before the None check,
                # so a missing record crashed; treat it as "no cached version".
                LOGGER.info("Cache version record not found, fall back to generating a new cache version")
                self.has_cache_version = False

        if self.has_cache_version:
            host_ids_process_pair = self._host_ids_from_cached_version(data_instances, current_version)
        else:
            host_ids_process_pair = self._host_ids_with_new_version(data_instances)

        LOGGER.info("remote version match info to guest")
        return host_ids_process_pair

    def run(self, data_instances):
        """Execute the host role of the RSA intersection.

        :param data_instances: table keyed by raw id
        :return: the intersection table when ``sync_intersect_ids`` is set,
            otherwise None
        """
        LOGGER.info("Start rsa intersection")
        self.e, self.d, self.n = self.get_rsa_key()
        LOGGER.info("Get rsa key!")
        public_key = {"e": self.e, "n": self.n}

        # NOTE(review): this trace line prints the RSA private exponent d.
        # It is kept for protocol analysis only; remove before production use.
        LOGGER.debug("Host generated RSA key pair: PK_e={}, PK_n={}, SK_d={}".format(self.e, self.n, self.d))
        self.transfer_variable.rsa_pubkey.remote(public_key,
                                                 role=consts.GUEST,
                                                 idx=0)
        LOGGER.info("Remote public key to Guest.")
        LOGGER.debug("Host sent to Guest: PK={}".format(public_key))

        host_ids_process_pair = self.host_ids_process(data_instances)
        # Z_H is only sent when the guest cannot read it from a matching cache.
        if not self.intersect_cache_param.use_cache or not self.is_version_match:
            host_ids_process = host_ids_process_pair.mapValues(lambda v: 1)
            LOGGER.debug("Host computed: Z_H={}, type={}".format(host_ids_process_pair, type(host_ids_process_pair)))
            self.transfer_variable.intersect_host_ids_process.remote(host_ids_process,
                                                                     role=consts.GUEST,
                                                                     idx=0)
            LOGGER.info("Remote host_ids_process to Guest.")
            # Logged inside the branch: host_ids_process does not exist otherwise.
            LOGGER.debug("Host sent to Guest: Z_H={}, type={}".format(host_ids_process, type(host_ids_process)))

        # Recv guest ids
        guest_ids = self.transfer_variable.intersect_guest_ids.get(idx=0)
        LOGGER.info("Get guest_ids from guest")

        # Process guest ids and return to guest
        LOGGER.debug("Host received from Guest: Y_G={}, type={}".format(guest_ids, type(guest_ids)))
        guest_ids_process = guest_ids.map(lambda k, v: (k, gmpy_math.powmod(int(k), self.d, self.n)))
        self.transfer_variable.intersect_guest_ids_process.remote(guest_ids_process,
                                                                  role=consts.GUEST,
                                                                  idx=0)
        LOGGER.debug("Host computed: Z_G={}, type={}".format(guest_ids_process, type(guest_ids_process)))
        LOGGER.info("Remote guest_ids_process to Guest.")
        LOGGER.debug("Host sent to Guest: Z_G={}, type={}".format(guest_ids_process, type(guest_ids_process)))

        # recv intersect ids
        intersect_ids = None
        if self.sync_intersect_ids:
            encrypt_intersect_ids = self.transfer_variable.intersect_ids.get(idx=0)
            intersect_ids_pair = encrypt_intersect_ids.join(host_ids_process_pair, lambda e, h: h)
            intersect_ids = intersect_ids_pair.map(lambda k, v: (v, "id"))
            LOGGER.info("Get intersect ids from Guest")
            LOGGER.debug("Host received from Guest: S={}, type={}".format(encrypt_intersect_ids, type(encrypt_intersect_ids)))
            if not self.only_output_key:
                intersect_ids = self._get_value_from_data(intersect_ids, data_instances)
        LOGGER.debug("Host computed: I_id={}, type={}".format(intersect_ids, type(intersect_ids)))
        return intersect_ids


class RawIntersectionHost(RawIntersect):
    """Host role of the raw (non-RSA) intersection.

    Depending on ``join_role`` the host either sends its ids or performs the
    join itself.
    """

    def __init__(self, intersect_params):
        super().__init__(intersect_params)
        self.join_role = intersect_params.join_role
        self.role = consts.HOST

    def run(self, data_instances):
        """Dispatch on ``join_role`` and return the resulting intersect ids.

        :raises ValueError: when ``join_role`` is neither guest nor host
        """
        LOGGER.info("Start raw intersection")

        # The joining party runs the join; the other party only sends its ids.
        if self.join_role == consts.GUEST:
            return self.intersect_send_id(data_instances)
        if self.join_role == consts.HOST:
            return self.intersect_join_id(data_instances)
        raise ValueError("Unknown intersect join role, please check the configure of host")

Guest侧:源码中添加日志

RsaIntersectionGuest->run()
 L142:
  LOGGER.debug("Guest received from Host: RSA public key PK_e={}, PK_n={}".format(self.e, self.n))
 L149:
  LOGGER.debug("Guest computed: Y_G={}, type={}".format(guest_id_process_list, type(guest_id_process_list)))
 L157:
  LOGGER.debug("Guest sent to Host: Y_G={}, type={}".format(mask_guest_id, type(mask_guest_id)))
 L161:
  LOGGER.debug("Guest received from Host: Z_H={}, type={}".format(host_ids_process_list, type(host_ids_process_list)))
 L167:
  LOGGER.debug("Guest received from Host: Z_G={}, type={}".format(recv_guest_ids_process, type(recv_guest_ids_process)))
 L173:
  LOGGER.debug("Guest computed: D_G={}, type={}".format(guest_ids_process_final, type(guest_ids_process_final)))
 L183:
  LOGGER.debug("Guest computed: S={}, type={}".format(encrypt_intersect_ids, type(encrypt_intersect_ids)))
 L206:
  LOGGER.debug("Guest sent to Host: S={}, type={}".format(remote_intersect_id, type(remote_intersect_id)))
 L209:
  LOGGER.debug("Guest computed: I_id={}, type={}".format(intersect_ids, type(intersect_ids)))

Guest侧:添加后的源码

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from collections import Iterable

import gmpy2
import random

from fate_arch.session import computing_session as session
from federatedml.secureprotol import gmpy_math
from federatedml.statistic.intersect import RawIntersect
from federatedml.statistic.intersect import RsaIntersect
#from federatedml.statistic.intersect.rsa_cache import cache_utils
from federatedml.util import consts
from federatedml.util import LOGGER
from federatedml.transfer_variable.transfer_class.rsa_intersect_transfer_variable import RsaIntersectTransferVariable


class RsaIntersectionGuest(RsaIntersect):
    """Guest side of the RSA-based secure ID intersection.

    For every host the guest blinds its hashed ids with ``r^e mod n`` (Y_G),
    receives the host's signature on them (Z_G), removes the blinding factor
    (D_G) and joins the result with the host's processed ids (Z_H) to obtain
    the intersection.

    NOTE(review): the cache code paths reference ``cache_utils``, whose import
    is commented out at the top of this file; with
    ``intersect_cache_param.use_cache`` enabled they raise ``NameError`` until
    that import is restored.
    NOTE(review): ``Iterable`` is taken from the module-level
    ``from collections import Iterable``; that alias was removed in
    Python 3.10 and should become ``from collections.abc import Iterable``.
    """

    def __init__(self, intersect_params):
        super().__init__(intersect_params)

        self.random_bit = intersect_params.random_bit

        # Per-host public key components, filled in by run().
        self.e = None
        self.n = None
        self.transfer_variable = RsaIntersectTransferVariable()

        # parameter for intersection cache
        self.intersect_cache_param = intersect_params.intersect_cache_param

    def map_raw_id_to_encrypt_id(self, raw_id_data, encrypt_id_data):
        """Return ``(encrypted id, "id")`` for every raw id in ``raw_id_data``.

        :param raw_id_data: table keyed by raw id
        :param encrypt_id_data: table of (encrypted id, raw id)
        """
        encrypt_id_data_exchange_kv = encrypt_id_data.map(lambda k, v: (v, k))
        encrypt_raw_id = raw_id_data.join(encrypt_id_data_exchange_kv, lambda r, e: e)
        encrypt_common_id = encrypt_raw_id.map(lambda k, v: (v, "id"))

        return encrypt_common_id

    def get_cache_version_match_info(self):
        """Exchange cache version information with every host.

        :return: the hosts' version-match answers, or None when the cache is
            not used
        """
        if self.intersect_cache_param.use_cache:
            LOGGER.info("Use cache is true")
            # check local cache version for each host
            for i, host_party_id in enumerate(self.host_party_id_list):
                current_version = cache_utils.guest_get_current_version(host_party_id=host_party_id,
                                                                        guest_party_id=None,
                                                                        id_type=self.intersect_cache_param.id_type,
                                                                        encrypt_type=self.intersect_cache_param.encrypt_type,
                                                                        tag='Za'
                                                                        )
                LOGGER.info("host_id:{}, current_version:{}".format(host_party_id, current_version))
                if current_version is None:
                    current_version = {"table_name": None, "namespace": None}

                self.transfer_variable.cache_version_info.remote(current_version,
                                                                 role=consts.HOST,
                                                                 idx=i)
                LOGGER.info("Remote current version to host:{}".format(host_party_id))

            cache_version_match_info = self.transfer_variable.cache_version_match_info.get(idx=-1)
            LOGGER.info("Get cache version match info:{}".format(cache_version_match_info))
        else:
            cache_version_match_info = None
            LOGGER.info("Not using cache, cache_version_match_info is None")

        return cache_version_match_info

    def get_host_id_process(self, cache_version_match_info):
        """Collect every host's processed ids (Z_H), from cache or from the host.

        :param cache_version_match_info: per-host version-match answers as
            returned by :meth:`get_cache_version_match_info`
        :return: list with one table per host
        """
        if not self.intersect_cache_param.use_cache:
            host_ids_process_list = self.transfer_variable.intersect_host_ids_process.get(idx=-1)
            LOGGER.info("Not using cache, get host_ids_process from all host")
            return host_ids_process_list

        host_ids_process_list = [None for _ in self.host_party_id_list]

        if isinstance(cache_version_match_info, Iterable):
            for i, version_info in enumerate(cache_version_match_info):
                if version_info.get('version_match'):
                    host_ids_process_list[i] = session.table(name=version_info.get('version'),
                                                             namespace=version_info.get('namespace'),
                                                             create_if_missing=True,
                                                             error_if_exist=False)
                    LOGGER.info("Read host {} 's host_ids_process from cache".format(self.host_party_id_list[i]))
        else:
            LOGGER.info("cache_version_match_info is not iterable, not use cache_version_match_info")

        # check which host_id_process is not receive yet
        host_ids_process_rev_idx_list = [i for i, e in enumerate(host_ids_process_list) if e is None]

        if len(host_ids_process_rev_idx_list) > 0:
            # Recv host_ids_process
            # table(host_id_process, 1)
            host_ids_process = []
            for rev_idx in host_ids_process_rev_idx_list:
                host_ids_process.append(self.transfer_variable.intersect_host_ids_process.get(idx=rev_idx))
                LOGGER.info("Get host_ids_process from host {}".format(self.host_party_id_list[rev_idx]))

            for i, host_idx in enumerate(host_ids_process_rev_idx_list):
                host_ids_process_list[host_idx] = host_ids_process[i]

                version = cache_version_match_info[host_idx].get('version')
                namespace = cache_version_match_info[host_idx].get('namespace')

                # `i` indexes the pending list only; the party id must be
                # looked up with the host's own index.
                cache_utils.store_cache(dtable=host_ids_process[i],
                                        guest_party_id=self.guest_party_id,
                                        host_party_id=self.host_party_id_list[host_idx],
                                        version=version,
                                        id_type=self.intersect_cache_param.id_type,
                                        encrypt_type=self.intersect_cache_param.encrypt_type,
                                        tag=consts.INTERSECT_CACHE_TAG,
                                        namespace=namespace
                                        )
                LOGGER.info("Store host {}'s host_ids_process to cache.".format(self.host_party_id_list[host_idx]))

        return host_ids_process_list

    @staticmethod
    def guest_id_process(sid, random_bit, rsa_e, rsa_n):
        """Blind one id: return ``(r^e * hash(sid) mod n, (sid, r))``.

        ``r`` is drawn from the OS random source with ``random_bit`` bits.
        """
        r = random.SystemRandom().getrandbits(random_bit)
        re_hash = gmpy_math.powmod(r, rsa_e, rsa_n) * int(RsaIntersectionGuest.hash(sid), 16) % rsa_n
        return re_hash, (sid, r)

    def _blind_ids(self, data_instances, rsa_e, rsa_n):
        """Blind all ids for one host's public key.

        A separate function scope binds ``rsa_e``/``rsa_n`` per host, so the
        lambda does not depend on a loop variable that changes later.
        """
        return data_instances.map(
            lambda k, v: self.guest_id_process(k, random_bit=self.random_bit, rsa_e=rsa_e, rsa_n=rsa_n))

    @staticmethod
    def _unblind_ids(blinded_ids, signed_ids, rsa_n):
        """Join blinded ids with the host's signatures and strip the factor r.

        :return: table of (blinded id, (sid, hash(signature / r mod n)))
        """
        return blinded_ids.join(
            signed_ids,
            lambda g, r: (g[0], RsaIntersectionGuest.hash(gmpy2.divm(int(r), int(g[1]), rsa_n))))

    def run(self, data_instances):
        """Execute the guest role of the RSA intersection.

        :param data_instances: table keyed by raw id
        :return: table of intersecting ids; joined with the values from
            ``data_instances`` unless ``only_output_key`` is set
        """
        LOGGER.info("Start rsa intersection")

        public_keys = self.transfer_variable.rsa_pubkey.get(-1)
        LOGGER.info("Get RSA public_key:{} from Host".format(public_keys))
        self.e = [int(public_key["e"]) for public_key in public_keys]
        self.n = [int(public_key["n"]) for public_key in public_keys]

        cache_version_match_info = self.get_cache_version_match_info()
        LOGGER.debug("Guest received from Host: RSA public key PK_e={}, PK_n={}".format(self.e, self.n))
        # table (r^e % n * hash(sid), sid, r)
        guest_id_process_list = [self._blind_ids(data_instances, rsa_e, rsa_n)
                                 for rsa_e, rsa_n in zip(self.e, self.n)]
        LOGGER.debug("Guest computed: Y_G={}, type={}".format(guest_id_process_list, type(guest_id_process_list)))

        # table(r^e % n *hash(sid), 1)
        for i, guest_id in enumerate(guest_id_process_list):
            mask_guest_id = guest_id.mapValues(lambda v: 1)
            self.transfer_variable.intersect_guest_ids.remote(mask_guest_id,
                                                              role=consts.HOST,
                                                              idx=i)
            LOGGER.info("Remote guest_id to Host {}".format(i))
            # Logged per host: mask_guest_id does not exist outside the loop
            # when there is no host.
            LOGGER.debug("Guest sent to Host: Y_G={}, type={}".format(mask_guest_id, type(mask_guest_id)))

        host_ids_process_list = self.get_host_id_process(cache_version_match_info)
        LOGGER.info("Get host_ids_process")
        LOGGER.debug("Guest received from Host: Z_H={}, type={}".format(host_ids_process_list, type(host_ids_process_list)))

        # Recv process guest ids
        # table(r^e % n *hash(sid), guest_id_process)
        recv_guest_ids_process = self.transfer_variable.intersect_guest_ids_process.get(idx=-1)
        LOGGER.info("Get guest_ids_process from Host")
        LOGGER.debug("Guest received from Host: Z_G={}, type={}".format(recv_guest_ids_process, type(recv_guest_ids_process)))

        # table(r^e % n *hash(sid), sid, hash(guest_ids_process/r))
        guest_ids_process_final = [self._unblind_ids(v, recv_guest_ids_process[i], self.n[i])
                                   for i, v in enumerate(guest_id_process_list)]
        LOGGER.debug("Guest computed: D_G={}, type={}".format(guest_ids_process_final, type(guest_ids_process_final)))

        # table(hash(guest_ids_process/r), sid))
        sid_guest_ids_process_final = [g.map(lambda k, v: (v[1], v[0])) for g in guest_ids_process_final]

        # intersect table(hash(guest_ids_process/r), sid)
        encrypt_intersect_ids = [v.join(host_ids_process_list[i], lambda sid, h: sid) for i, v in
                                 enumerate(sid_guest_ids_process_final)]
        # Logged for both the single-host and the multi-host case.
        LOGGER.debug("Guest computed: S={}, type={}".format(encrypt_intersect_ids, type(encrypt_intersect_ids)))

        if len(self.host_party_id_list) > 1:
            raw_intersect_ids = [e.map(lambda k, v: (v, 1)) for e in encrypt_intersect_ids]
            intersect_ids = self.get_common_intersection(raw_intersect_ids)

            # send intersect id
            if self.sync_intersect_ids:
                for i, host_party_id in enumerate(self.host_party_id_list):
                    remote_intersect_id = self.map_raw_id_to_encrypt_id(intersect_ids, encrypt_intersect_ids[i])
                    self.transfer_variable.intersect_ids.remote(remote_intersect_id,
                                                                role=consts.HOST,
                                                                idx=i)
                    LOGGER.info("Remote intersect ids to Host {}!".format(host_party_id))
                    LOGGER.debug("Guest sent to Host: S={}, type={}".format(remote_intersect_id, type(remote_intersect_id)))
            else:
                LOGGER.info("Not send intersect ids to Host!")
        else:
            intersect_ids = encrypt_intersect_ids[0]
            if self.sync_intersect_ids:
                remote_intersect_id = intersect_ids.mapValues(lambda v: 1)
                self.transfer_variable.intersect_ids.remote(remote_intersect_id,
                                                            role=consts.HOST,
                                                            idx=0)
                # Only logged when something was sent; the variable is
                # undefined otherwise.
                LOGGER.debug("Guest sent to Host: S={}, type={}".format(remote_intersect_id, type(remote_intersect_id)))

            intersect_ids = intersect_ids.map(lambda k, v: (v, 1))
        LOGGER.info("Finish intersect_ids computing")

        LOGGER.debug("Guest computed: I_id={}, type={}".format(intersect_ids, type(intersect_ids)))
        if not self.only_output_key:
            intersect_ids = self._get_value_from_data(intersect_ids, data_instances)

        return intersect_ids


class RawIntersectionGuest(RawIntersect):
    """Guest role of the raw (non-RSA) intersection.

    Depending on ``join_role`` the guest either sends its ids or performs the
    join itself.
    """

    def __init__(self, intersect_params):
        super().__init__(intersect_params)
        self.role = consts.GUEST
        self.join_role = intersect_params.join_role

    def run(self, data_instances):
        """Dispatch on ``join_role`` and return the resulting intersect ids.

        :raises ValueError: when ``join_role`` is neither host nor guest
        """
        LOGGER.info("Start raw intersection")

        # The joining party runs the join; the other party only sends its ids.
        if self.join_role == consts.HOST:
            return self.intersect_send_id(data_instances)
        if self.join_role == consts.GUEST:
            return self.intersect_join_id(data_instances)
        raise ValueError("Unknown intersect join role, please check the configure of guest")

 

2. 模型训练:纵向线性回归算法模型训练中间参数的计算采用Paillier同态加密算法进行保护,并依赖中间方Arbiter进行收敛判断,其流程及交互消息如下图所示。

源码:

./federatedml/linear_model/base_linear_model_arbiter.py
HeteroBaseArbiter->fit()
 L79:
  LOGGER.info("iter:" + str(self.n_iter_))

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

from federatedml.framework.hetero.procedure import convergence
from federatedml.framework.hetero.procedure import paillier_cipher, batch_generator
from federatedml.linear_model.linear_model_base import BaseLinearModel
from federatedml.util import LOGGER
from federatedml.util import consts
from federatedml.util import fate_operator
from federatedml.util.validation_strategy import ValidationStrategy


class HeteroBaseArbiter(BaseLinearModel):
    """Arbiter-role training loop shared by hetero linear models.

    ``fit`` obtains a cipher operator from ``self.cipher.paillier_keygen``,
    runs the gradient and loss procedures of ``gradient_loss_operator`` batch
    by batch, decides convergence and pushes the flag out through
    ``converge_procedure.sync_converge_info``.
    """

    def __init__(self):
        # NOTE(review): super(BaseLinearModel, self) resumes the MRO lookup
        # *after* BaseLinearModel, so BaseLinearModel.__init__ itself is not
        # invoked here. Left unchanged; confirm this is intentional.
        super(BaseLinearModel, self).__init__()
        self.role = consts.ARBITER

        # attribute
        self.pre_loss = None
        self.loss_history = []
        self.cipher = paillier_cipher.Arbiter()
        self.batch_generator = batch_generator.Arbiter()
        # Dereferenced in fit(); presumably assigned by subclasses beforehand.
        self.gradient_loss_operator = None
        self.converge_procedure = convergence.Arbiter()
        # -1 selects the last loss_history entry unless early stopping sets it.
        self.best_iteration = -1

    def perform_subtasks(self, **training_info):
        """
        performs any tasks that the arbiter is responsible for.

        This 'perform_subtasks' function serves as a handler on conducting any task that the arbiter is responsible
        for. For example, for the 'perform_subtasks' function of 'HeteroDNNLRArbiter' class located in
        'hetero_dnn_lr_arbiter.py', it performs some works related to updating/training local neural networks of guest
        or host.

        For this particular class, the 'perform_subtasks' function will do nothing. In other words, no subtask is
        performed by this arbiter.

        :param training_info: a dictionary holding training information
        """
        pass

    def init_validation_strategy(self, train_data=None, validate_data=None):
        """Build a ValidationStrategy from this model's validation settings.

        ``train_data`` and ``validate_data`` are accepted but not used here.
        """
        validation_strategy = ValidationStrategy(self.role, self.mode, self.validation_freqs,
                                                 self.early_stopping_rounds,
                                                 self.use_first_metric_only)
        return validation_strategy

    def fit(self, data_instances=None, validate_data=None):
        """
        Train linear model of role arbiter

        Parameters
        ----------
        data_instances: DTable of Instance, input data
        validate_data: only forwarded to init_validation_strategy

        Returns nothing; the outcome is published through ``set_summary``.
        Raises ValueError when loss-based early stop is configured but no
        single loss value was collected during the iteration.
        """

        LOGGER.info("Enter hetero linear model arbiter fit")

        self.cipher_operator = self.cipher.paillier_keygen(self.model_param.encrypt_param.key_length)
        self.batch_generator.initialize_batch_generator()
        self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_num)

        self.validation_strategy = self.init_validation_strategy(data_instances, validate_data)

        while self.n_iter_ < self.max_iter:
            # This line was tab-indented in the pasted source, which made the
            # loop body unparseable; it now uses the same 12 spaces as the rest.
            LOGGER.info("iter:" + str(self.n_iter_))
            iter_loss = None
            batch_data_generator = self.batch_generator.generate_batch_data()
            total_gradient = None
            self.optimizer.set_iters(self.n_iter_)
            for batch_index in batch_data_generator:
                # Compute and Transfer gradient info
                gradient = self.gradient_loss_operator.compute_gradient_procedure(self.cipher_operator,
                                                                                  self.optimizer,
                                                                                  self.n_iter_,
                                                                                  batch_index)
                if total_gradient is None:
                    total_gradient = gradient
                else:
                    total_gradient = total_gradient + gradient
                training_info = {"iteration": self.n_iter_, "batch_index": batch_index}
                self.perform_subtasks(**training_info)

                loss_list = self.gradient_loss_operator.compute_loss(self.cipher_operator, self.n_iter_, batch_index)

                # Only a single-entry loss list is accumulated; any other
                # length leaves iter_loss untouched for this batch.
                if len(loss_list) == 1:
                    if iter_loss is None:
                        iter_loss = loss_list[0]
                    else:
                        iter_loss += loss_list[0]
                        # LOGGER.info("Get loss from guest:{}".format(de_loss))

            # if converge
            if iter_loss is not None:
                # Average the summed batch losses over the number of batches.
                iter_loss /= self.batch_generator.batch_num
                if self.need_call_back_loss:
                    self.callback_loss(self.n_iter_, iter_loss)
                self.loss_history.append(iter_loss)

            if self.model_param.early_stop == 'weight_diff':
                # LOGGER.debug("total_gradient: {}".format(total_gradient))
                weight_diff = fate_operator.norm(total_gradient)
                # LOGGER.info("iter: {}, weight_diff:{}, is_converged: {}".format(self.n_iter_,
                #                                                                 weight_diff, self.is_converged))
                if weight_diff < self.model_param.tol:
                    self.is_converged = True
            else:
                if iter_loss is None:
                    raise ValueError("Multiple host situation, loss early stop function is not available."
                                     "You should use 'weight_diff' instead")
                self.is_converged = self.converge_func.is_converge(iter_loss)
                LOGGER.info("iter: {},  loss:{}, is_converged: {}".format(self.n_iter_, iter_loss, self.is_converged))

            self.converge_procedure.sync_converge_info(self.is_converged, suffix=(self.n_iter_,))

            if self.validation_strategy:
                LOGGER.debug('Linear Arbiter running validation')
                self.validation_strategy.validate(self, self.n_iter_)
                if self.validation_strategy.need_stop():
                    LOGGER.debug('early stopping triggered')
                    self.best_iteration = self.n_iter_
                    break

            self.n_iter_ += 1
            if self.is_converged:
                break
        summary = {"loss_history": self.loss_history,
                   "is_converged": self.is_converged,
                   "best_iteration": self.best_iteration}
        if self.validation_strategy and self.validation_strategy.has_saved_best_model():
            self.load_model(self.validation_strategy.cur_best_model)
        if self.loss_history is not None and len(self.loss_history) > 0:
            summary["best_iter_loss"] = self.loss_history[self.best_iteration]

        self.set_summary(summary)
        LOGGER.debug("finish running linear model arbiter")


./federatedml/framework/hetero/sync/converge_sync.py
 L19~20:
  from arch.api.utils import log_utils
  LOGGER = log_utils.getLogger()
Arbiter->sync_converge_info()
 L29:
  LOGGER.debug("Arbiter sent to Host: is_converged={}, type={}".format(is_converged, type(is_converged)))
 L31:
  LOGGER.debug("Arbiter sent to Guest: is_converged={}, type={}".format(is_converged, type(is_converged)))

#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#


from federatedml.util import consts

from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()
class Arbiter(object):
    """Arbiter side of the convergence-flag synchronisation."""

    # noinspection PyAttributeOutsideInit
    def _register_convergence(self, is_stopped_transfer):
        """Store the transfer object used to publish the convergence flag."""
        self._is_stopped_transfer = is_stopped_transfer

    def sync_converge_info(self, is_converged, suffix=tuple()):
        """Send ``is_converged`` to the HOST role, then to the GUEST role (idx=-1).

        Each debug line is written right after the send it describes, so the
        log stays truthful if the second send raises.
        """
        self._is_stopped_transfer.remote(obj=is_converged, role=consts.HOST, idx=-1, suffix=suffix)
        LOGGER.debug("Arbiter sent to Host: is_converged={}, type={}".format(is_converged, type(is_converged)))

        self._is_stopped_transfer.remote(obj=is_converged, role=consts.GUEST, idx=-1, suffix=suffix)
        LOGGER.debug("Arbiter sent to Guest: is_converged={}, type={}".format(is_converged, type(is_converged)))
class _Client(object):
    """Receiving side of the convergence-flag synchronisation."""

    # noinspection PyAttributeOutsideInit
    def _register_convergence(self, is_stopped_transfer):
        """Store the transfer object the flag is later read from."""
        self._is_stopped_transfer = is_stopped_transfer

    def sync_converge_info(self, suffix=tuple()):
        """Return whatever the registered transfer yields for idx=0 and ``suffix``."""
        return self._is_stopped_transfer.get(idx=0, suffix=suffix)


# Host and Guest behave identically here, so both names point at _Client.
Host = Guest = _Client

./federatedml/framework/hetero/sync/loss_sync.py
 L21~22:
  from arch.api.utils import log_utils
  LOGGER = log_utils.getLogger()
Guest->sync_loss_info()
 L41:LOGGER.debug("Guest sent to Arbiter: [[L]]={}, type={}".format(loss, type(loss)))
Guest->get_host_loss_intermediate()
 L45:LOGGER.debug("Guest received from Host: [[(u^H)^2]]={}, type={}".format(loss_intermediate, type(loss_intermediate)))
Guest->get_host_loss_regular()
 L50:
  LOGGER.debug("Guest received from Host: [[(λ/2)(Θ^H)^2]]={}, type={}".format(losses, type(losses)))
Host->remote_loss_intermediate()
 L62:
  LOGGER.debug("Host sent to Guest: [[(u^H)^2]]={}, type={}".format(loss_intermediate, type(loss_intermediate)))
Host->remote_loss_regular()
 L66:
  LOGGER.debug("Host sent to Guest: [[(λ/2)(Θ^H)^2]]={}, type={}".format(loss_regular, type(loss_regular)))

#!/usr/bin/env python
# -*- coding: utf-8 -*-

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from federatedml.util import consts

from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()

class Arbiter(object):
    """Arbiter side of the loss exchange: it only receives."""

    def _register_loss_sync(self, loss_transfer):
        """Store the transfer object the loss is read from."""
        self.loss_transfer = loss_transfer

    def sync_loss_info(self, suffix=tuple()):
        """Return the value fetched from the loss transfer for idx=0 and ``suffix``."""
        return self.loss_transfer.get(idx=0, suffix=suffix)


class Guest(object):
    """Guest side of the loss exchange with the Host and the Arbiter."""

    def _register_loss_sync(self, host_loss_regular_transfer, loss_transfer, loss_intermediate_transfer):
        """Store the three transfer objects used by the methods of this class."""
        self.loss_intermediate_transfer = loss_intermediate_transfer
        self.loss_transfer = loss_transfer
        self.host_loss_regular_transfer = host_loss_regular_transfer

    def sync_loss_info(self, loss, suffix=tuple()):
        """Send ``loss`` to the ARBITER role (idx=0) and log what was sent."""
        self.loss_transfer.remote(loss, role=consts.ARBITER, idx=0, suffix=suffix)
        sent_message = "Guest sent to Arbiter: [[L]]={}, type={}".format(loss, type(loss))
        LOGGER.debug(sent_message)

    def get_host_loss_intermediate(self, suffix=tuple()):
        """Fetch the intermediate loss term with idx=-1, log it and return it."""
        received = self.loss_intermediate_transfer.get(idx=-1, suffix=suffix)
        LOGGER.debug(
            "Guest received from Host: [[(u^H)^2]]={}, type={}".format(received, type(received)))
        return received

    def get_host_loss_regular(self, suffix=tuple()):
        """Fetch the regularisation loss term with idx=-1, log it and return it."""
        received = self.host_loss_regular_transfer.get(idx=-1, suffix=suffix)
        LOGGER.debug(
            "Guest received from Host: [[(λ/2)(Θ^H)^2]]={}, type={}".format(received, type(received)))
        return received


class Host(object):
    """Host side of the loss exchange: it only sends, always to the Guest."""

    def _register_loss_sync(self, host_loss_regular_transfer, loss_transfer, loss_intermediate_transfer):
        """Store the three transfer objects handed in by the caller."""
        self.loss_intermediate_transfer = loss_intermediate_transfer
        self.loss_transfer = loss_transfer
        self.host_loss_regular_transfer = host_loss_regular_transfer

    def remote_loss_intermediate(self, loss_intermediate, suffix=tuple()):
        """Send ``loss_intermediate`` to the GUEST role (idx=0) and log it."""
        self.loss_intermediate_transfer.remote(obj=loss_intermediate, role=consts.GUEST, idx=0, suffix=suffix)
        sent_message = "Host sent to Guest: [[(u^H)^2]]={}, type={}".format(
            loss_intermediate, type(loss_intermediate))
        LOGGER.debug(sent_message)

    def remote_loss_regular(self, loss_regular, suffix=tuple()):
        """Send ``loss_regular`` to the GUEST role (idx=0) and log it."""
        self.host_loss_regular_transfer.remote(obj=loss_regular, role=consts.GUEST, idx=0, suffix=suffix)
        sent_message = "Host sent to Guest: [[(λ/2)(Θ^H)^2]]={}, type={}".format(
            loss_regular, type(loss_regular))
        LOGGER.debug(sent_message)

./federatedml/framework/hetero/sync/paillier_keygen_sync.py
 L20~21:
  from arch.api.utils import log_utils
  LOGGER = log_utils.getLogger()
Arbiter->paillier_keygen()
 L32~33:
  priv_key = cipher.get_privacy_key()
  LOGGER.debug("Arbiter generated a Paillier key pair: PK={}, SK={}".format(pub_key, priv_key))
 L35:
  LOGGER.debug("Arbiter sent to Host: PK={}".format(pub_key))
 L37:
  LOGGER.debug("Arbiter sent to Guest: PK={}".format(pub_key))

#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#


from federatedml.secureprotol.encrypt import PaillierEncrypt
from federatedml.util import consts

from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()

class Arbiter(object):
    """Arbiter side of the Paillier key distribution."""

    # noinspection PyAttributeOutsideInit
    def _register_paillier_keygen(self, pubkey_transfer):
        """Store the transfer object used to publish the public key."""
        self._pubkey_transfer = pubkey_transfer

    def paillier_keygen(self, key_length, suffix=tuple()):
        """Generate a key pair and send the public key to the HOST and GUEST roles.

        :param key_length: forwarded to ``PaillierEncrypt.generate_key``
        :param suffix: transfer suffix forwarded to both ``remote`` calls
        :return: the ``PaillierEncrypt`` instance holding the generated keys
        """
        cipher = PaillierEncrypt()
        cipher.generate_key(key_length)
        pub_key = cipher.get_public_key()
        priv_key = cipher.get_privacy_key()
        # NOTE(review): this trace line writes the private key to the debug
        # log. It is kept because the instrumentation exists to trace the
        # protocol, but it must not be enabled outside a test deployment.
        # (The line was tab-indented in the pasted source, which made the
        # module unparseable; it is now aligned with its neighbours.)
        LOGGER.debug("Arbiter generated a Paillier key pair: PK={}, SK={}".format(pub_key, priv_key))
        self._pubkey_transfer.remote(obj=pub_key, role=consts.HOST, idx=-1, suffix=suffix)
        LOGGER.debug("Arbiter sent to Host: PK={}".format(pub_key))
        self._pubkey_transfer.remote(obj=pub_key, role=consts.GUEST, idx=-1, suffix=suffix)
        LOGGER.debug("Arbiter sent to Guest: PK={}".format(pub_key))
        return cipher


class _Client(object):
    """Receiving side of the Paillier key distribution."""

    # noinspection PyAttributeOutsideInit
    def _register_paillier_keygen(self, pubkey_transfer):
        """Store the transfer object the public key is later read from."""
        self._pubkey_transfer = pubkey_transfer

    def gen_paillier_cipher_operator(self, suffix=tuple()):
        """Fetch the public key (idx=0) and return a PaillierEncrypt carrying it."""
        received_key = self._pubkey_transfer.get(idx=0, suffix=suffix)
        operator = PaillierEncrypt()
        operator.set_public_key(received_key)
        return operator


# Host and Guest behave identically here, so both names point at _Client.
Host = Guest = _Client


./federatedml/secureprotol/fate_paillier.py
 L24~25:
  from arch.api.utils import log_utils
  LOGGER = log_utils.getLogger()
PaillierPublicKey
 L71~72:
  def __str__(self):
   return "PaillierPublicKey: g={}, n={}, max_int={}".format(hex(self.g), hex(self.n), hex(self.max_int))
PaillierPublicKey->apply_obfuscator()
 L80~83:
  new_ciphertext = (ciphertext * obfuscator) % self.nsquare
  LOGGER.debug("PaillierPublicKey->apply_obfuscator(): old_cipher={}, new_cipher={}, random={}".format(
   hex(ciphertext), hex(new_ciphertext), hex(r)))
  return new_ciphertext
PaillierPublicKey->raw_encrypt()
 L100~101:
  LOGGER.debug("PaillierPublicKey->raw_encrypt(): plaintext={}, ciphertext={}, PK={}".format(
   hex(plaintext), hex(ciphertext), self))
PaillierPrivateKey
 L151~153:
  def __str__(self):
   return "PaillierPrivateKey: p={}, q={}".format(
    hex(self.p), hex(self.q))
PaillierPrivateKey->raw_decrypt()
 L191~194:
  result = self.crt(mp, mq)
  LOGGER.debug("PaillierPrivateKey->raw_decrypt(): ciphertext={}, plaintext={}, SK={}".format(
   hex(ciphertext), hex(result), self))
  return result
PaillierEncryptedNumber
 L231~233:
  def __str__(self):
   return "PaillierEncryptedNumber: public_key={}, ciphertext={}, exponent={}".format(
    self.public_key, hex(self.__ciphertext), hex(self.exponent))

"""Paillier encryption library for partially homomorphic encryption."""

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

from collections.abc import Mapping
from federatedml.secureprotol.fixedpoint import FixedPointNumber
from federatedml.secureprotol import gmpy_math
import random

from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()

class PaillierKeypair(object):
    """Factory for matching Paillier public/private key objects."""

    def __init__(self):
        pass

    @staticmethod
    def generate_keypair(n_length=1024):
        """Return a new (PaillierPublicKey, PaillierPrivateKey) tuple.

        Two distinct primes are drawn with ``gmpy_math.getprimeover`` using
        ``n_length // 2`` as argument, and redrawn until their product has a
        bit length of exactly ``n_length``.
        """
        prime_bits = n_length // 2
        p = q = n = None
        modulus_bits = 0

        while modulus_bits != n_length:
            p = gmpy_math.getprimeover(prime_bits)
            q = gmpy_math.getprimeover(prime_bits)
            # Keep drawing q until it differs from p.
            while q == p:
                q = gmpy_math.getprimeover(prime_bits)
            n = p * q
            modulus_bits = n.bit_length()

        public_key = PaillierPublicKey(n)
        return public_key, PaillierPrivateKey(public_key, p, q)


class PaillierPublicKey(object):
    """Contains a public key and associated encryption methods.
    """

    def __init__(self, n):
        """Derive g = n + 1, n**2 and max_int = n // 3 - 1 from the modulus ``n``."""
        self.g = n + 1
        self.n = n
        self.nsquare = n * n
        self.max_int = n // 3 - 1

    def __repr__(self):
        hashcode = hex(hash(self))[2:]
        return "<PaillierPublicKey {}>".format(hashcode[:10])

    def __eq__(self, other):
        return self.n == other.n

    def __hash__(self):
        return hash(self.n)

    def __str__(self):
        return "PaillierPublicKey: g={}, n={}, max_int={}".format(hex(self.g), hex(self.n), hex(self.max_int))

    def apply_obfuscator(self, ciphertext, random_value=None):
        """Return ``ciphertext * r**n mod n**2``.

        ``r`` is ``random_value`` when that is truthy, otherwise a fresh value
        drawn from ``random.SystemRandom`` in [1, n).
        """
        r = random_value or random.SystemRandom().randrange(1, self.n)
        obfuscator = gmpy_math.powmod(r, self.n, self.nsquare)

        new_ciphertext = (ciphertext * obfuscator) % self.nsquare
        # NOTE(review): trace output includes the random r; debug use only.
        # (These lines were tab-indented in the pasted source and did not parse.)
        LOGGER.debug("PaillierPublicKey->apply_obfuscator(): old_cipher={}, new_cipher={}, random={}".format(
            hex(ciphertext), hex(new_ciphertext), hex(r)))
        return new_ciphertext

    def raw_encrypt(self, plaintext, random_value=None):
        """Encrypt the integer ``plaintext`` and return the integer ciphertext.

        Raises TypeError when ``plaintext`` is not an int. The result is passed
        through ``apply_obfuscator`` with ``random_value`` before it is returned.
        """
        if not isinstance(plaintext, int):
            raise TypeError("plaintext should be int, but got: %s" %
                            type(plaintext))

        if plaintext >= (self.n - self.max_int) and plaintext < self.n:
            # Very large plaintext, take a sneaky shortcut using inverses
            neg_plaintext = self.n - plaintext  # = abs(plaintext - n)
            neg_ciphertext = (self.n * neg_plaintext + 1) % self.nsquare
            ciphertext = gmpy_math.invert(neg_ciphertext, self.nsquare)
        else:
            ciphertext = (self.n * plaintext + 1) % self.nsquare

        # NOTE(review): trace output includes the plaintext; debug use only.
        LOGGER.debug("PaillierPublicKey->raw_encrypt(): plaintext={}, ciphertext={}, PK={}".format(
            hex(plaintext), hex(ciphertext), self))
        ciphertext = self.apply_obfuscator(ciphertext, random_value)

        return ciphertext

    def encrypt(self, value, precision=None, random_value=None):
        """Encode and Paillier encrypt a real number value.

        When ``random_value`` is None, raw_encrypt is called with 1 (which
        leaves the ciphertext unchanged) and the obfuscation is applied
        afterwards through the returned PaillierEncryptedNumber.
        """
        encoding = FixedPointNumber.encode(value, self.n, self.max_int, precision)
        obfuscator = random_value or 1
        ciphertext = self.raw_encrypt(encoding.encoding, random_value=obfuscator)
        encryptednumber = PaillierEncryptedNumber(self, ciphertext, encoding.exponent)
        if random_value is None:
            encryptednumber.apply_obfuscator()

        return encryptednumber


class PaillierPrivateKey(object):
    """Contains a private key and associated decryption method.
    """

    def __init__(self, public_key, p, q):
        """Validate ``p``/``q`` against ``public_key`` and precompute decryption values.

        Raises ValueError when ``p * q`` differs from ``public_key.n`` or when
        ``p == q``. The primes are stored so that ``self.p <= self.q``.
        """
        if not p * q == public_key.n:
            raise ValueError("given public key does not match the given p and q")
        if p == q:
            raise ValueError("p and q have to be different")
        self.public_key = public_key
        if q < p:
            self.p = q
            self.q = p
        else:
            self.p = p
            self.q = q
        self.psquare = self.p * self.p
        self.qsquare = self.q * self.q
        self.q_inverse = gmpy_math.invert(self.q, self.p)
        self.hp = self.h_func(self.p, self.psquare)
        self.hq = self.h_func(self.q, self.qsquare)

    def __eq__(self, other):
        return self.p == other.p and self.q == other.q

    def __hash__(self):
        return hash((self.p, self.q))

    def __repr__(self):
        hashcode = hex(hash(self))[2:]

        return "<PaillierPrivateKey {}>".format(hashcode[:10])

    def __str__(self):
        # NOTE(review): exposes both primes; intended for protocol tracing only.
        # (The body was tab-indented in the pasted source and did not parse.)
        return "PaillierPrivateKey: p={}, q={}".format(
            hex(self.p), hex(self.q))

    def h_func(self, x, xsquare):
        """Computes the h-function as defined in Paillier's paper page.
        """
        return gmpy_math.invert(self.l_func(gmpy_math.powmod(self.public_key.g,
                                                             x - 1, xsquare), x), x)

    def l_func(self, x, p):
        """computes the L function as defined in Paillier's paper.
        """

        return (x - 1) // p

    def crt(self, mp, mq):
        """the Chinese Remainder Theorem as needed for decryption.
           return the solution modulo n=pq.
        """
        u = (mp - mq) * self.q_inverse % self.p
        x = (mq + (u * self.q)) % self.public_key.n

        return x

    def raw_decrypt(self, ciphertext):
        """return raw plaintext.

        Raises TypeError when ``ciphertext`` is not an int.
        """
        if not isinstance(ciphertext, int):
            raise TypeError("ciphertext should be an int, not: %s" %
                            type(ciphertext))

        mp = self.l_func(gmpy_math.powmod(ciphertext,
                                          self.p - 1, self.psquare),
                         self.p) * self.hp % self.p

        mq = self.l_func(gmpy_math.powmod(ciphertext,
                                          self.q - 1, self.qsquare),
                         self.q) * self.hq % self.q

        result = self.crt(mp, mq)
        # The format string previously ended in a bare "SK" with no
        # placeholder, so the third argument was silently discarded.
        # NOTE(review): this trace line prints the private key and the
        # plaintext; it must stay out of production logging.
        LOGGER.debug("PaillierPrivateKey->raw_decrypt(): ciphertext={}, plaintext={}, SK={}".format(
            hex(ciphertext), hex(result), self))
        return result

    def decrypt(self, encrypted_number):
        """return the decrypted & decoded plaintext of encrypted_number.

        Raises TypeError when the argument is not a PaillierEncryptedNumber and
        ValueError when it was encrypted under a different public key.
        """
        if not isinstance(encrypted_number, PaillierEncryptedNumber):
            # Adjacent literals replace the backslash continuation that
            # embedded a run of indentation spaces in the message.
            raise TypeError("encrypted_number should be an PaillierEncryptedNumber, "
                            "not: %s" % type(encrypted_number))

        if self.public_key != encrypted_number.public_key:
            raise ValueError("encrypted_number was encrypted against a different key!")

        encoded = self.raw_decrypt(encrypted_number.ciphertext(be_secure=False))
        encoded = FixedPointNumber(encoded,
                                   encrypted_number.exponent,
                                   self.public_key.n,
                                   self.public_key.max_int)
        decrypt_value = encoded.decode()

        return decrypt_value


class PaillierEncryptedNumber(object):
    """Represents the Paillier encryption of a float or int.
    """

    def __init__(self, public_key, ciphertext, exponent=0):
        """Wrap an integer ciphertext together with its key and exponent.

        Raises TypeError when ``ciphertext`` is not an int or ``public_key`` is
        not a PaillierPublicKey.
        """
        self.public_key = public_key
        self.__ciphertext = ciphertext
        self.exponent = exponent
        # Tracks whether apply_obfuscator() has already been run on this value.
        self.__is_obfuscator = False

        if not isinstance(self.__ciphertext, int):
            raise TypeError("ciphertext should be an int, not: %s" % type(self.__ciphertext))

        if not isinstance(self.public_key, PaillierPublicKey):
            raise TypeError("public_key should be a PaillierPublicKey, not: %s" % type(self.public_key))

    def __str__(self):
        # The body was tab-indented in the pasted source and did not parse.
        return "PaillierEncryptedNumber: public_key={}, ciphertext={}, exponent={}".format(
            self.public_key, hex(self.__ciphertext), hex(self.exponent))

    def ciphertext(self, be_secure=True):
        """return the ciphertext of the PaillierEncryptedNumber.

        With ``be_secure`` true, the value is obfuscated first unless that has
        already happened.
        """
        if be_secure and not self.__is_obfuscator:
            self.apply_obfuscator()

        return self.__ciphertext

    def apply_obfuscator(self):
        """ciphertext by multiplying by r ** n with random r
        """
        self.__ciphertext = self.public_key.apply_obfuscator(self.__ciphertext)
        self.__is_obfuscator = True

    def __add__(self, other):
        if isinstance(other, PaillierEncryptedNumber):
            return self.__add_encryptednumber(other)
        else:
            return self.__add_scalar(other)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        return self + (other * -1)

    def __rsub__(self, other):
        return other + (self * -1)

    def __rmul__(self, scalar):
        return self.__mul__(scalar)

    def __truediv__(self, scalar):
        return self.__mul__(1 / scalar)

    def __mul__(self, scalar):
        """return Multiply by an scalar(such as int, float)

        Raises ValueError when the encoded scalar falls outside [0, n).
        """

        encode = FixedPointNumber.encode(scalar, self.public_key.n, self.public_key.max_int)
        plaintext = encode.encoding

        if plaintext < 0 or plaintext >= self.public_key.n:
            raise ValueError("Scalar out of bounds: %i" % plaintext)

        if plaintext >= self.public_key.n - self.public_key.max_int:
            # Very large plaintext, play a sneaky trick using inverses
            neg_c = gmpy_math.invert(self.ciphertext(False), self.public_key.nsquare)
            neg_scalar = self.public_key.n - plaintext
            ciphertext = gmpy_math.powmod(neg_c, neg_scalar, self.public_key.nsquare)
        else:
            ciphertext = gmpy_math.powmod(self.ciphertext(False), plaintext, self.public_key.nsquare)

        exponent = self.exponent + encode.exponent

        return PaillierEncryptedNumber(self.public_key, ciphertext, exponent)

    def increase_exponent_to(self, new_exponent):
        """return PaillierEncryptedNumber:
           new PaillierEncryptedNumber with same value but having great exponent.

        Raises ValueError when ``new_exponent`` is smaller than the current one.
        """
        if new_exponent < self.exponent:
            raise ValueError("New exponent %i should be great than old exponent %i" % (new_exponent, self.exponent))

        factor = pow(FixedPointNumber.BASE, new_exponent - self.exponent)
        new_encryptednumber = self.__mul__(factor)
        new_encryptednumber.exponent = new_exponent

        return new_encryptednumber

    def __align_exponent(self, x, y):
        """return x,y with same exponet
        """
        if x.exponent < y.exponent:
            x = x.increase_exponent_to(y.exponent)
        elif x.exponent > y.exponent:
            y = y.increase_exponent_to(x.exponent)

        return x, y

    def __add_scalar(self, scalar):
        """return PaillierEncryptedNumber: z = E(x) + y
        """
        encoded = FixedPointNumber.encode(scalar,
                                          self.public_key.n,
                                          self.public_key.max_int,
                                          max_exponent=self.exponent)

        return self.__add_fixpointnumber(encoded)

    def __add_fixpointnumber(self, encoded):
        """return PaillierEncryptedNumber: z = E(x) + FixedPointNumber(y)
        """
        if self.public_key.n != encoded.n:
            raise ValueError("Attempted to add numbers encoded against different public keys!")

        # their exponents must match, and align.
        x, y = self.__align_exponent(self, encoded)

        # random_value=1 makes the obfuscation step a no-op for the scalar.
        encrypted_scalar = x.public_key.raw_encrypt(y.encoding, 1)
        encryptednumber = self.__raw_add(x.ciphertext(False), encrypted_scalar, x.exponent)

        return encryptednumber

    def __add_encryptednumber(self, other):
        """return PaillierEncryptedNumber: z = E(x) + E(y)
        """
        if self.public_key != other.public_key:
            raise ValueError("add two numbers have different public key!")

        # their exponents must match, and align.
        x, y = self.__align_exponent(self, other)

        encryptednumber = self.__raw_add(x.ciphertext(False), y.ciphertext(False), x.exponent)

        return encryptednumber

    def __raw_add(self, e_x, e_y, exponent):
        """return the integer E(x + y) given ints E(x) and E(y).
        """
        ciphertext = e_x * e_y % self.public_key.nsquare

        return PaillierEncryptedNumber(self.public_key, ciphertext, exponent)


./federatedml/secureprotol/fixedpoint.py
 L22~23:
  from arch.api.utils import log_utils
  LOGGER = log_utils.getLogger()
FixedPointNumber->encode()
 L86~88:
  result = cls(int_fixpoint % n, exponent, n, max_int)
  LOGGER.debug("FixedPointNumber->encode(): value={}, FP={}".format(scalar, result))
  return result
FixedPointNumber->decode()
 L105~107:
  result = mantissa * pow(self.BASE, -self.exponent)
  LOGGER.debug("FixedPointNumber->decode(): FP={}, value={}".format(self, result))
  return result
FixedPointNumber
 L271~273:
  def __str__(self):
   return "FixedPointNumber: encoding={}, exponent={}".format(
    hex(self.encoding), hex(self.exponent))

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

import math
import sys

import numpy as np

from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()

class FixedPointNumber(object):
    """Fixed-point encoding of an int or float.

    A value is stored as an integer ``encoding`` (reduced modulo ``n``) and an
    ``exponent`` such that ``value == mantissa * BASE ** -exponent``.
    Encodings up to ``max_int`` are read as non-negative mantissas; encodings
    of at least ``n - max_int`` are read as the negative mantissa
    ``encoding - n``.
    """
    BASE = 16
    LOG2_BASE = math.log(BASE, 2)
    FLOAT_MANTISSA_BITS = sys.float_info.mant_dig

    # Modulus used whenever the caller does not supply ``n``.
    Q = 293973345475167247070445277780365744413

    def __init__(self, encoding, exponent, n=None, max_int=None):
        """Store an already-encoded number.

        When ``n`` is None the default modulus ``Q`` is used and ``max_int``
        is reset to ``Q // 3 - 1`` regardless of the value passed in.
        """
        self.n = n
        self.max_int = max_int

        if self.n is None:
            self.n = self.Q
            self.max_int = self.Q // 3 - 1

        self.encoding = encoding
        self.exponent = exponent

    @classmethod
    def encode(cls, scalar, n=None, max_int=None, precision=None, max_exponent=None):
        """Return a FixedPointNumber encoding of an int or float.

        Parameters
        ----------
        scalar: int or float (Python or numpy), value to encode
        n: int, modulus; when None the default Q is used and max_int is reset
            to Q // 3 - 1
        max_int: int, largest mantissa magnitude accepted
        precision: float, when given the exponent is floor(log_BASE(precision)),
            otherwise the exponent is derived from the type of scalar
        max_exponent: int, when given the exponent is raised to at least this value

        Returns
        ----------
        FixedPointNumber

        Raises
        ----------
        TypeError
            precision is None and scalar is neither an int nor a float
        ValueError
            the scaled integer exceeds max_int in magnitude
        """
        #  Flush tiny magnitudes to zero before scaling;
        #  avoid "OverflowError: int too large to convert to float"
        if np.abs(scalar) < 1e-200:
            scalar = 0

        if n is None:
            n = cls.Q
            max_int = cls.Q // 3 - 1

        if precision is None:
            if isinstance(scalar, (int, np.int16, np.int32, np.int64)):
                exponent = 0
            elif isinstance(scalar, (float, np.float16, np.float32, np.float64)):
                # Keep every mantissa bit of the float: place the least
                # significant bit at BASE ** -exponent.
                flt_exponent = math.frexp(scalar)[1]
                lsb_exponent = cls.FLOAT_MANTISSA_BITS - flt_exponent
                exponent = math.floor(lsb_exponent / cls.LOG2_BASE)
            else:
                raise TypeError("Don't know the precision of type %s."
                                % type(scalar))
        else:
            exponent = math.floor(math.log(precision, cls.BASE))

        if max_exponent is not None:
            exponent = max(max_exponent, exponent)

        int_fixpoint = int(round(scalar * pow(cls.BASE, exponent)))

        if abs(int_fixpoint) > max_int:
            raise ValueError('Integer needs to be within +/- %d but got %d'
                             % (max_int, int_fixpoint))

        # "% n" maps a negative mantissa into the upper part of [0, n).
        result = cls(int_fixpoint % n, exponent, n, max_int)
        LOGGER.debug("FixedPointNumber->encode(): value={}, FP={}".format(scalar, result))
        return result

    def decode(self):
        """Return the decoded plaintext value.

        Raises
        ----------
        ValueError
            the encoding is not reduced modulo n
        OverflowError
            the encoding lies between max_int and n - max_int
        """
        if self.encoding >= self.n:
            # Should be mod n
            raise ValueError('Attempted to decode corrupted number')
        elif self.encoding <= self.max_int:
            # Positive
            mantissa = self.encoding
        elif self.encoding >= self.n - self.max_int:
            # Negative
            mantissa = self.encoding - self.n
        else:
            raise OverflowError('Overflow detected in decode number')

        result = mantissa * pow(self.BASE, -self.exponent)
        LOGGER.debug("FixedPointNumber->decode(): FP={}, value={}".format(self, result))
        return result

    def increase_exponent_to(self, new_exponent):
        """Return a FixedPointNumber with the same value and a greater exponent.

        Raises
        ----------
        ValueError
            new_exponent is smaller than the current exponent
        """
        if new_exponent < self.exponent:
            raise ValueError('New exponent %i should be greater than '
                             'old exponent %i' % (new_exponent, self.exponent))

        factor = pow(self.BASE, new_exponent - self.exponent)
        new_encoding = self.encoding * factor % self.n

        return FixedPointNumber(new_encoding, new_exponent, self.n, self.max_int)

    def __align_exponent(self, x, y):
        """Return x, y brought to the same (larger) exponent."""
        if x.exponent < y.exponent:
            x = x.increase_exponent_to(y.exponent)
        elif x.exponent > y.exponent:
            y = y.increase_exponent_to(x.exponent)

        return x, y

    def __truncate(self, a):
        """Re-encode the decoded value of ``a`` with the default parameters."""
        scalar = a.decode()
        return FixedPointNumber.encode(scalar)

    def __decoded_operands(self, other):
        """Return (self, other) as plain scalars, decoding FixedPointNumbers."""
        x = self.decode()
        y = other.decode() if isinstance(other, FixedPointNumber) else other
        return x, y

    def __add__(self, other):
        if isinstance(other, FixedPointNumber):
            return self.__add_fixpointnumber(other)
        return self.__add_scalar(other)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        if isinstance(other, FixedPointNumber):
            return self.__sub_fixpointnumber(other)
        return self.__sub_scalar(other)

    def __rsub__(self, other):
        # other - self == -(self - other)
        x = self.__sub__(other)
        x = -1 * x.decode()
        return self.encode(x)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __mul__(self, other):
        if isinstance(other, FixedPointNumber):
            return self.__mul_fixpointnumber(other)
        return self.__mul_scalar(other)

    def __truediv__(self, other):
        if isinstance(other, FixedPointNumber):
            scalar = other.decode()
        else:
            scalar = other

        return self.__mul__(1 / scalar)

    def __rtruediv__(self, other):
        # other / self == 1 / (self / other)
        res = 1.0 / self.__truediv__(other).decode()
        return FixedPointNumber.encode(res)

    def __lt__(self, other):
        x, y = self.__decoded_operands(other)
        return bool(x < y)

    def __gt__(self, other):
        x, y = self.__decoded_operands(other)
        return bool(x > y)

    def __le__(self, other):
        x, y = self.__decoded_operands(other)
        return bool(x <= y)

    def __ge__(self, other):
        x, y = self.__decoded_operands(other)
        return bool(x >= y)

    def __eq__(self, other):
        x, y = self.__decoded_operands(other)
        return bool(x == y)

    def __ne__(self, other):
        x, y = self.__decoded_operands(other)
        return bool(x != y)

    def __add_fixpointnumber(self, other):
        x, y = self.__align_exponent(self, other)
        # NOTE(review): the sum is reduced modulo the default Q and the result
        # is built without n / max_int, so it always uses the default modulus
        # even when self was created with a custom n - confirm this is intended.
        encoding = (x.encoding + y.encoding) % self.Q
        return FixedPointNumber(encoding, x.exponent)

    def __add_scalar(self, scalar):
        encoded = self.encode(scalar)
        return self.__add_fixpointnumber(encoded)

    def __sub_fixpointnumber(self, other):
        scalar = -1 * other.decode()
        return self.__add_scalar(scalar)

    def __sub_scalar(self, scalar):
        scalar = -1 * scalar
        return self.__add_scalar(scalar)

    def __mul_fixpointnumber(self, other):
        # NOTE(review): like addition, the product uses the default Q rather
        # than self.n - confirm for numbers created with a custom n.
        encoding = (self.encoding * other.encoding) % self.Q
        exponent = self.exponent + other.exponent
        mul_fixedpoint = FixedPointNumber(encoding, exponent)
        # Decode and re-encode so the exponent does not grow with every product.
        return self.__truncate(mul_fixedpoint)

    def __mul_scalar(self, scalar):
        encoded = self.encode(scalar)
        return self.__mul_fixpointnumber(encoded)

    def __str__(self):
        return "FixedPointNumber: encoding={}, exponent={}".format(
            hex(self.encoding), hex(self.exponent))


./federatedml/optim/gradient/hetero_linear_model_gradient.py
Guest->compute_gradient_procedure()
 L178:
  LOGGER.debug("Guest computed: [[d]]x^G={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
 L181:
  LOGGER.debug("Guest computed: [[d]]x^G+[[λΘ^G]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
Guest->get_host_forward()
 L188:
  LOGGER.debug("Guest received from Host: [[u^H]]={}, type={}".format(host_forward, type(host_forward)))
Guest->remote_fore_gradient()
 L193:
  LOGGER.debug("Guest sent to Host: [[d]]={}, type={}".format(fore_gradient, type(fore_gradient)))
Guest->update_gradient()
 L197:
  LOGGER.debug("Guest sent to Arbiter: [[g^G]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
 L199:
  LOGGER.debug("Guest received from Arbiter: g^G={}, type={}".format(optimized_gradient, type(optimized_gradient)))
Host->compute_gradient_procedure()
 L234:
  LOGGER.debug("Host computed: u^H={}, type={}".format(list(self.forwards.collect()), type(self.forwards)))
 L236:
  LOGGER.debug("Host encrypted: [[u^H]]={}, type={}".format(list(encrypted_forward.collect()), type(encrypted_forward)))
 L244:
  LOGGER.debug("Host computed: [[d]]x^H={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
 L247:
  LOGGER.debug("Host computed: [[d]]x^H+[[λΘ^H]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
Host->remote_host_forward()
 L278:
  LOGGER.debug("Host sent to Guest: [[u^H]]={}, type={}".format(host_forward, type(host_forward)))
Host->get_fore_gradient()
 L282:
  LOGGER.debug("Host received from Guest: fore_gradient={}, type={}".format(host_forward, type(host_forward)))
Host->update_gradient()
 L287:
  LOGGER.debug("Host sent to Arbiter: [[g^H]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
 L289:
  LOGGER.debug("Host received from Arbiter: g^H={}, type={}".format(optimized_gradient, type(optimized_gradient)))

#!/usr/bin/env python
# -*- coding: utf-8 -*-

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import functools

import numpy as np
import scipy.sparse as sp

from federatedml.feature.sparse_vector import SparseVector
from federatedml.statistic import data_overview
from federatedml.util import LOGGER
from federatedml.util import consts
from federatedml.util import fate_operator


def __compute_partition_gradient(data, fit_intercept=True, is_sparse=False):
    """
    Accumulate the hetero regression gradient of one partition:
    gradient = ∑d*x, where d is the fore_gradient paired with each row
    Parameters
    ----------
    data: iterable of (key, (features, fore_gradient)) pairs
    fit_intercept: bool, append ∑d as the intercept gradient when True. Default True
    is_sparse: bool, True when every features item is a SparseVector. Default False

    Returns
    ----------
    numpy.ndarray
        gradient of this partition; the int 0 when the partition has no rows
        (or, for sparse data, when the feature shape is 0)
    """
    if not is_sparse:
        dense_rows = []
        d_values = []
        for _, pair in data:
            dense_rows.append(pair[0])
            d_values.append(pair[1])
        dense_rows = np.array(dense_rows)
        d_values = np.array(d_values)
        if dense_rows.shape[0] <= 0:
            return 0

        grad_list = fate_operator.dot(dense_rows.transpose(), d_values).tolist()
        if fit_intercept:
            grad_list.append(np.sum(d_values))
        return np.array(grad_list)

    # Sparse input: collect COO-style triplets, then build one CSR matrix.
    row_ids, col_ids, cell_values = [], [], []
    d_values = []
    n_rows = 0
    n_features = None
    for _, (sparse_features, d) in data:
        d_values.append(d)
        assert isinstance(sparse_features, SparseVector)
        if n_features is None:
            # The first row decides the number of columns.
            n_features = sparse_features.get_shape()
        for col, cell in sparse_features.get_all_data():
            col_ids.append(col)
            row_ids.append(n_rows)
            cell_values.append(cell)
        n_rows += 1
    if n_features is None or n_features == 0:
        return 0

    matrix = sp.csr_matrix((cell_values, (row_ids, col_ids)), shape=(n_rows, n_features))
    d_values = np.array(d_values)

    grad_list = fate_operator.dot(matrix.transpose(), d_values).tolist()
    if fit_intercept:
        grad_list.append(np.sum(d_values))
    return np.array(grad_list)


def compute_gradient(data_instances, fore_gradient, fit_intercept):
    """
    Compute the averaged hetero-regression gradient
    Parameters
    ----------
    data_instances: DTable, input data whose values expose ``features``
    fore_gradient: DTable, fore_gradient joined to data_instances by key
    fit_intercept: bool, if model has intercept or not

    Returns
    ----------
    the per-partition gradients summed with ``+`` and divided by
    ``data_instances.count()``
    """
    joined = data_instances.join(fore_gradient,
                                 lambda inst, g: (inst.features, g))
    sparse_flag = data_overview.is_sparse_data(data_instances)
    partition_fn = functools.partial(__compute_partition_gradient,
                                     fit_intercept=fit_intercept,
                                     is_sparse=sparse_flag)
    summed = joined.applyPartitions(partition_fn).reduce(lambda x, y: x + y)
    return summed / data_instances.count()


class HeteroGradientBase(object):
    """Common base class of the gradient roles defined in this module."""

    def compute_gradient_procedure(self, *args):
        """Always raises; each role provides its own procedure."""
        raise NotImplementedError("Should not call here")

    def set_total_batch_nums(self, total_batch_nums):
        """
        Use for sqn gradient. Does nothing in the base class.
        """
        return None


class Guest(HeteroGradientBase):
    """Guest-side gradient procedure of a hetero linear model."""

    def __init__(self):
        self.host_forwards = None
        self.forwards = None
        self.aggregated_forwards = None

    def _register_gradient_sync(self, host_forward_transfer, fore_gradient_transfer,
                                guest_gradient_transfer, guest_optim_gradient_transfer):
        """Store the transfer variables used to exchange forwards and gradients."""
        self.host_forward_transfer = host_forward_transfer
        self.fore_gradient_transfer = fore_gradient_transfer
        self.unilateral_gradient_transfer = guest_gradient_transfer
        self.unilateral_optim_gradient_transfer = guest_optim_gradient_transfer

    def compute_and_aggregate_forwards(self, data_instances, model_weights,
                                       encrypted_calculator, batch_index, current_suffix, offset=None):
        """Algorithm-specific hook; must be overridden by subclasses."""
        raise NotImplementedError("Function should not be called here")

    def compute_gradient_procedure(self, data_instances, encrypted_calculator, model_weights, optimizer,
                                   n_iter_, batch_index, offset=None):
        """
          Linear model gradient procedure
          Step 1: Compute self forwards, aggregate host forwards and get
                  d = fore_gradient (done by compute_and_aggregate_forwards)

          Step 2: Send fore_gradient to the hosts

          Step 3: Compute unilateral gradient = ∑d*x and, when an optimizer is
                  given, add the regular term to it

          Step 4: Send unilateral gradients to arbiter and received the optimized and decrypted gradient.

          Returns
          ----------
          tuple of (optimized_gradient, fore_gradient, self.host_forwards)
          """
        current_suffix = (n_iter_, batch_index)

        fore_gradient = self.compute_and_aggregate_forwards(data_instances, model_weights, encrypted_calculator,
                                                            batch_index, current_suffix, offset)

        self.remote_fore_gradient(fore_gradient, suffix=current_suffix)

        unilateral_gradient = compute_gradient(data_instances,
                                               fore_gradient,
                                               model_weights.fit_intercept)
        # Log the raw gradient before the regular term is added, mirroring the
        # Host procedure in this module.
        LOGGER.debug("Guest computed: [[d]]x^G={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
        if optimizer is not None:
            unilateral_gradient = optimizer.add_regular_to_grad(unilateral_gradient, model_weights)
        LOGGER.debug("Guest computed: [[d]]x^G+[[λΘ^G]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))

        optimized_gradient = self.update_gradient(unilateral_gradient, suffix=current_suffix)

        return optimized_gradient, fore_gradient, self.host_forwards

    def get_host_forward(self, suffix=tuple()):
        """Receive the forwards sent by the hosts (``idx=-1``) and return them."""
        host_forward = self.host_forward_transfer.get(idx=-1, suffix=suffix)
        LOGGER.debug("Guest received from Host: [[u^H]]={}, type={}".format(host_forward, type(host_forward)))
        return host_forward

    def remote_fore_gradient(self, fore_gradient, suffix=tuple()):
        """Send fore_gradient to the hosts (``idx=-1``)."""
        self.fore_gradient_transfer.remote(obj=fore_gradient, role=consts.HOST, idx=-1, suffix=suffix)
        LOGGER.debug("Guest sent to Host: [[d]]={}, type={}".format(fore_gradient, type(fore_gradient)))

    def update_gradient(self, unilateral_gradient, suffix=tuple()):
        """Send the unilateral gradient to the arbiter and return its reply."""
        self.unilateral_gradient_transfer.remote(unilateral_gradient, role=consts.ARBITER, idx=0, suffix=suffix)
        LOGGER.debug("Guest sent to Arbiter: [[g^G]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
        optimized_gradient = self.unilateral_optim_gradient_transfer.get(idx=0, suffix=suffix)
        LOGGER.debug("Guest received from Arbiter: g^G={}, type={}".format(optimized_gradient, type(optimized_gradient)))
        return optimized_gradient


class Host(HeteroGradientBase):
    """Host-side gradient procedure of a hetero linear model."""

    def __init__(self):
        self.forwards = None
        self.fore_gradient = None

    def _register_gradient_sync(self, host_forward_transfer, fore_gradient_transfer,
                                host_gradient_transfer, host_optim_gradient_transfer):
        """Keep the transfer variables used to exchange forwards and gradients."""
        self.host_forward_transfer = host_forward_transfer
        self.fore_gradient_transfer = fore_gradient_transfer
        self.unilateral_gradient_transfer = host_gradient_transfer
        self.unilateral_optim_gradient_transfer = host_optim_gradient_transfer

    def compute_forwards(self, data_instances, model_weights):
        """Algorithm-specific hook; must be overridden by subclasses."""
        raise NotImplementedError("Function should not be called here")

    def compute_unilateral_gradient(self, data_instances, fore_gradient, model_weights, optimizer):
        """Algorithm-specific hook; must be overridden by subclasses."""
        raise NotImplementedError("Function should not be called here")

    def compute_gradient_procedure(self, data_instances, encrypted_calculator, model_weights,
                                   optimizer,
                                   n_iter_, batch_index):
        """
        Linear model gradient procedure
        Step 1: compute the host forwards and encrypt them with
                encrypted_calculator[batch_index]
        Step 2: send the encrypted forwards to the guest and receive fore_gradient
        Step 3: compute unilateral gradient = ∑d*x and, when an optimizer is
                given, add the regular term to it
        Step 4: send the unilateral gradient to the arbiter and receive the
                optimized gradient

        Returns
        ----------
        tuple of (optimized_gradient, fore_gradient)
        """
        suffix = (n_iter_, batch_index)

        self.forwards = self.compute_forwards(data_instances, model_weights)
        LOGGER.debug("Host computed: u^H={}, type={}".format(list(self.forwards.collect()), type(self.forwards)))
        cipher_forwards = encrypted_calculator[batch_index].encrypt(self.forwards)
        LOGGER.debug("Host encrypted: [[u^H]]={}, type={}".format(list(cipher_forwards.collect()), type(cipher_forwards)))

        self.remote_host_forward(cipher_forwards, suffix=suffix)
        fore_gradient = self.get_fore_gradient(suffix=suffix)

        local_gradient = compute_gradient(data_instances,
                                          fore_gradient,
                                          model_weights.fit_intercept)
        LOGGER.debug("Host computed: [[d]]x^H={}, type={}".format(local_gradient, type(local_gradient)))
        if optimizer is not None:
            local_gradient = optimizer.add_regular_to_grad(local_gradient, model_weights)
        LOGGER.debug("Host computed: [[d]]x^H+[[λΘ^H]]={}, type={}".format(local_gradient, type(local_gradient)))

        return self.update_gradient(local_gradient, suffix=suffix), fore_gradient

    def compute_sqn_forwards(self, data_instances, delta_s, cipher_operator):
        """
        To compute Hessian matrix, y, s are needed.
        g = (1/N)*∑(0.25 * wx - 0.5 * y) * x
        y = ∇2^F(w_t)s_t = g' * s = (1/N)*∑(0.25 * x * s) * x
        define forward_hess = ∑(0.25 * x * s)

        Each instance is mapped to encrypt(x · delta_s.coef_ + delta_s.intercept_).
        """
        return data_instances.mapValues(
            lambda inst: cipher_operator.encrypt(
                fate_operator.vec_dot(inst.features, delta_s.coef_) + delta_s.intercept_))

    def compute_forward_hess(self, data_instances, delta_s, forward_hess):
        """
        To compute Hessian matrix, y, s are needed.
        g = (1/N)*∑(0.25 * wx - 0.5 * y) * x
        y = ∇2^F(w_t)s_t = g' * s = (1/N)*∑(0.25 * x * s) * x
        define forward_hess = (0.25 * x * s)

        Returns compute_gradient(data_instances, forward_hess, ...) as an ndarray.
        """
        return np.array(compute_gradient(data_instances,
                                         forward_hess,
                                         delta_s.fit_intercept))

    def remote_host_forward(self, host_forward, suffix=tuple()):
        """Send the host forwards to the guest (``idx=0``)."""
        self.host_forward_transfer.remote(obj=host_forward, role=consts.GUEST, idx=0, suffix=suffix)
        LOGGER.debug("Host sent to Guest: [[u^H]]={}, type={}".format(host_forward, type(host_forward)))

    def get_fore_gradient(self, suffix=tuple()):
        """Receive fore_gradient from the guest (``idx=0``) and return it."""
        received = self.fore_gradient_transfer.get(idx=0, suffix=suffix)
        LOGGER.debug("Host received from Guest: fore_gradient={}, type={}".format(received, type(received)))
        return received

    def update_gradient(self, unilateral_gradient, suffix=tuple()):
        """Send the unilateral gradient to the arbiter and return its reply."""
        self.unilateral_gradient_transfer.remote(unilateral_gradient, role=consts.ARBITER, idx=0, suffix=suffix)
        LOGGER.debug("Host sent to Arbiter: [[g^H]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
        reply = self.unilateral_optim_gradient_transfer.get(idx=0, suffix=suffix)
        LOGGER.debug("Host received from Arbiter: g^H={}, type={}".format(reply, type(reply)))
        return reply


class Arbiter(HeteroGradientBase):
    """Arbiter-side gradient procedure: decrypt, optimize and hand back gradients."""

    def __init__(self):
        self.has_multiple_hosts = False

    def _register_gradient_sync(self, guest_gradient_transfer, host_gradient_transfer,
                                guest_optim_gradient_transfer, host_optim_gradient_transfer):
        """Store the transfer variables used to exchange gradients with guest and hosts."""
        self.guest_gradient_transfer = guest_gradient_transfer
        self.host_gradient_transfer = host_gradient_transfer
        self.guest_optim_gradient_transfer = guest_optim_gradient_transfer
        self.host_optim_gradient_transfer = host_optim_gradient_transfer

    def compute_gradient_procedure(self, cipher_operator, optimizer, n_iter_, batch_index):
        """
        Compute gradients.
        Received local_gradients from guest and hosts. Merge and optimize, then separate and remote back.
        Parameters
        ----------
        cipher_operator: decrypts the merged gradient via ``decrypt_list``

        optimizer: optimizer that get delta gradient of this iter

        n_iter_: int, current iter nums

        batch_index: int, index of the current batch; with n_iter_ it forms the transfer suffix

        Returns
        ----------
        the value returned by ``optimizer.apply_gradients`` for the merged gradient
        """
        current_suffix = (n_iter_, batch_index)

        host_gradients, guest_gradient = self.get_local_gradient(current_suffix)

        if len(host_gradients) > 1:
            self.has_multiple_hosts = True

        host_gradients = [np.array(h) for h in host_gradients]
        guest_gradient = np.array(guest_gradient)

        size_list = [h_g.shape[0] for h_g in host_gradients]
        size_list.append(guest_gradient.shape[0])

        # np.hstack requires a real sequence; generators are rejected by
        # current NumPy releases. Layout stays: all hosts, then the guest.
        gradient = np.hstack(host_gradients + [guest_gradient])

        grad = np.array(cipher_operator.decrypt_list(gradient))

        delta_grad = optimizer.apply_gradients(grad)

        # Split back in the same order the parts were stacked.
        separate_optim_gradient = self.separate(delta_grad, size_list)
        host_optim_gradients = separate_optim_gradient[: -1]
        guest_optim_gradient = separate_optim_gradient[-1]

        self.remote_local_gradient(host_optim_gradients, guest_optim_gradient, current_suffix)
        return delta_grad

    @staticmethod
    def separate(value, size_list):
        """
        Separate value in order to several set according size_list
        Parameters
        ----------
        value: list or ndarray, input data
        size_list: list, each set size

        Returns
        ----------
        list
            consecutive slices of value, one per entry of size_list
        """
        separate_res = []
        cur = 0
        for size in size_list:
            separate_res.append(value[cur:cur + size])
            cur += size
        return separate_res

    def get_local_gradient(self, suffix=tuple()):
        """Receive the gradients of the hosts (``idx=-1``) and of the guest (``idx=0``)."""
        host_gradients = self.host_gradient_transfer.get(idx=-1, suffix=suffix)
        LOGGER.info("Get host_gradient from Host")

        guest_gradient = self.guest_gradient_transfer.get(idx=0, suffix=suffix)
        LOGGER.info("Get guest_gradient from Guest")
        return host_gradients, guest_gradient

    def remote_local_gradient(self, host_optim_gradients, guest_optim_gradient, suffix=tuple()):
        """Send each optimized gradient back: host i gets item i, the guest gets its own."""
        for idx, host_optim_gradient in enumerate(host_optim_gradients):
            self.host_optim_gradient_transfer.remote(host_optim_gradient,
                                                     role=consts.HOST,
                                                     idx=idx,
                                                     suffix=suffix)

        self.guest_optim_gradient_transfer.remote(guest_optim_gradient,
                                                  role=consts.GUEST,
                                                  idx=0,
                                                  suffix=suffix)


./federatedml/optim/gradient/hetero_linr_gradient_and_loss.py
Guest->compute_and_aggregate_forwards()
 L65:
  LOGGER.debug("Guest computed forwards: u^G={}, type={}".format(self.forwards, type(self.forwards)))
 L67:
  LOGGER.debug("Guest encrypted forwards: [[u^G]]={}, type={}".format(self.aggregated_forwards, type(self.aggregated_forwards)))
 L71:
  LOGGER.debug("Guest computed: [[u^G+u^H]]={}, type={}".format(self.aggregated_forwards, type(self.aggregated_forwards)))
 L73:
  LOGGER.debug("Guest computed: [[d]]=[[u^G+u^H-y]]={}, type={}".format(fore_gradient, type(fore_gradient)))
Guest->compute_loss()
 L98:
  LOGGER.debug("Guest computed: (u^G-y)={}, type={}".format(wxy, type(wxy)))
 L100:
  LOGGER.debug("Guest computed: (u^G-y)^2={}, type={}".format(wxy_square, type(wxy_square)))
 L103:
  LOGGER.debug("Guest computed: [[u^H(u^G-y)]]={}, type={}".format(loss_gh, type(loss_gh)))
 L105:
  LOGGER.debug("Guest computed: [[(u^H+u^G-y)^2]]={}, type={}".format(loss, type(loss)))
 L108:
  LOGGER.debug("Guest computed: [[L]]={}, type={}".format(loss, type(loss)))
Host->compute_loss()
 L155:
  LOGGER.debug("Host computed: (u^H)^2={}, type={}".format(self_wx_square, type(self_wx_square)))
 L157:
  LOGGER.debug("Host encrypted: [[(u^H)^2]]={}, type={}".format(en_wx_square, type(en_wx_square)))
 L162:
  LOGGER.debug("Host computed: (λ/2)(Θ^H)^2={}, type={}".format(loss_regular, type(loss_regular)))
 L164:
  LOGGER.debug("Host encrypted: [[(λ/2)(Θ^H)^2]]={}, type={}".format(en_loss_regular, type(en_loss_regular)))

#!/usr/bin/env python
# -*- coding: utf-8 -*-

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import numpy as np

from federatedml.framework.hetero.sync import loss_sync
from federatedml.optim.gradient import hetero_linear_model_gradient
from federatedml.util import LOGGER
from federatedml.util.fate_operator import reduce_add, vec_dot


class Guest(hetero_linear_model_gradient.Guest, loss_sync.Guest):
    """Guest-side gradient and loss computation for hetero linear regression."""

    def register_gradient_procedure(self, transfer_variables):
        """Register the gradient and loss transfer variables used by the guest."""
        self._register_gradient_sync(transfer_variables.host_forward,
                                     transfer_variables.fore_gradient,
                                     transfer_variables.guest_gradient,
                                     transfer_variables.guest_optim_gradient)

        self._register_loss_sync(transfer_variables.host_loss_regular,
                                 transfer_variables.loss,
                                 transfer_variables.loss_intermediate)

    def compute_and_aggregate_forwards(self, data_instances, model_weights,
                                       encrypted_calculator, batch_index, current_suffix, offset=None):
        """
        Compute gradients:
        gradient = (1/N)*∑(wx - y)*x

        Define wx as guest_forward or host_forward
        Define (wx-y) as fore_gradient
        Parameters
        ----------
        data_instances: DTable of Instance, input data

        model_weights: LinearRegressionWeights
            Stores coef_ and intercept_ of model

        encrypted_calculator: Use for different encrypted methods

        offset: Used in Poisson only.

        batch_index: int, use to obtain current encrypted_calculator index:

        current_suffix: tuple or string. Used in transfer_variable

        Returns
        ----------
        fore_gradient: the aggregated encrypted forwards minus each instance's label
        """
        wx = data_instances.mapValues(
            lambda v: vec_dot(v.features, model_weights.coef_) + model_weights.intercept_)
        self.forwards = wx
        LOGGER.debug("Guest computed forwards: u^G={}, type={}".format(self.forwards, type(self.forwards)))
        self.aggregated_forwards = encrypted_calculator[batch_index].encrypt(wx)
        LOGGER.debug("Guest encrypted forwards: [[u^G]]={}, type={}".format(self.aggregated_forwards, type(self.aggregated_forwards)))
        self.host_forwards = self.get_host_forward(suffix=current_suffix)

        for host_forward in self.host_forwards:
            self.aggregated_forwards = self.aggregated_forwards.join(host_forward, lambda g, h: g + h)
            # Log after the join so the message shows the aggregated value.
            LOGGER.debug("Guest computed: [[u^G+u^H]]={}, type={}".format(self.aggregated_forwards, type(self.aggregated_forwards)))
        fore_gradient = self.aggregated_forwards.join(data_instances, lambda wx, d: wx - d.label)
        # fore_gradient must be assigned before it is logged.
        LOGGER.debug("Guest computed: [[d]]=[[u^G+u^H-y]]={}, type={}".format(fore_gradient, type(fore_gradient)))
        return fore_gradient

    def compute_loss(self, data_instances, n_iter_, batch_index, loss_norm=None):
        """
        Compute hetero linr loss:
            loss = (1/2N)*∑(wx-y)^2 where y is label, w is model weight and x is features
        With wx = wx_h + wx_g:
            (wx - y)^2 = (wx_h)^2 + (wx_g - y)^2 + 2*wx_h*(wx_g - y)

        When loss_norm is given, loss_norm and the first host's regular loss
        are added. With more than one host no loss is computed and an empty
        list is synced. The loss list is passed to sync_loss_info; nothing is
        returned.
        """
        current_suffix = (n_iter_, batch_index)
        n = data_instances.count()
        loss_list = []
        host_wx_squares = self.get_host_loss_intermediate(current_suffix)

        if loss_norm is not None:
            host_loss_regular = self.get_host_loss_regular(suffix=current_suffix)
        else:
            host_loss_regular = []
        if len(self.host_forwards) > 1:
            LOGGER.info("More than one host exist, loss is not available")
        else:
            host_forward = self.host_forwards[0]
            host_wx_square = host_wx_squares[0]
            wxy = self.forwards.join(data_instances, lambda wx, d: wx - d.label)
            LOGGER.debug("Guest computed: (u^G-y)={}, type={}".format(wxy, type(wxy)))
            wxy_square = wxy.mapValues(lambda x: np.square(x)).reduce(reduce_add)
            LOGGER.debug("Guest computed: (u^G-y)^2={}, type={}".format(wxy_square, type(wxy_square)))
            loss_gh = wxy.join(host_forward, lambda g, h: g * h).reduce(reduce_add)

            LOGGER.debug("Guest computed: [[u^H(u^G-y)]]={}, type={}".format(loss_gh, type(loss_gh)))
            loss = (wxy_square + host_wx_square + 2 * loss_gh) / (2 * n)
            LOGGER.debug("Guest computed: [[(u^H+u^G-y)^2]]={}, type={}".format(loss, type(loss)))
            if loss_norm is not None:
                loss = loss + loss_norm + host_loss_regular[0]
                LOGGER.debug("Guest computed: [[L]]={}, type={}".format(loss, type(loss)))
            loss_list.append(loss)
        LOGGER.debug("In compute_loss, loss list are: {}".format(loss_list))
        self.sync_loss_info(loss_list, suffix=current_suffix)

    def compute_forward_hess(self, data_instances, delta_s, host_forwards):
        """
        To compute Hessian matrix, y, s are needed.
        g = (1/N)*∑(wx - y) * x
        y = ∇2^F(w_t)s_t = g' * s = (1/N)*∑(x * s) * x
        define forward_hess = (1/N)*∑(x * s)

        Returns
        ----------
        tuple of (forwards, hess_vector): the guest forwards with every host
        forward added, and the gradient computed from them as an ndarray
        """
        forwards = data_instances.mapValues(
            lambda v: (np.dot(v.features, delta_s.coef_) + delta_s.intercept_))
        for host_forward in host_forwards:
            forwards = forwards.join(host_forward, lambda g, h: g + h)
        hess_vector = hetero_linear_model_gradient.compute_gradient(data_instances,
                                                                    forwards,
                                                                    delta_s.fit_intercept)
        return forwards, np.array(hess_vector)


class Host(hetero_linear_model_gradient.Host, loss_sync.Host):
    """Host-side gradient and loss procedure of hetero linear regression."""

    def register_gradient_procedure(self, transfer_variables):
        """
        Register the transfer variables used for gradient and loss exchange.

        Parameters
        ----------
        transfer_variables: object exposing host_forward, fore_gradient,
            host_gradient, host_optim_gradient, host_loss_regular, loss and
            loss_intermediate
        """
        self._register_gradient_sync(transfer_variables.host_forward,
                                     transfer_variables.fore_gradient,
                                     transfer_variables.host_gradient,
                                     transfer_variables.host_optim_gradient)
        self._register_loss_sync(transfer_variables.host_loss_regular,
                                 transfer_variables.loss,
                                 transfer_variables.loss_intermediate)

    def compute_forwards(self, data_instances, model_weights):
        """
        Compute the host forward term for every instance.

        Returns a table whose values are
        vec_dot(features, model_weights.coef_) + model_weights.intercept_.
        """
        wx = data_instances.mapValues(lambda v: vec_dot(v.features, model_weights.coef_) + model_weights.intercept_)
        return wx

    def compute_loss(self, model_weights, optimizer, n_iter_, batch_index, cipher_operator):
        """
        Send the host's share of the hetero linR loss to the guest.

            loss = (1/2N) * sum((wx - y)^2)
            where y is label, w is model weight and x is features

            Note: (wx - y)^2 = (wx_h)^2 + (wx_g - y)^2 + 2 * wx_h * (wx_g - y)

        The host encrypts the sum of its squared forwards, (wx_h)^2, and
        remotes it as the loss intermediate. When optimizer.loss_norm returns
        a value, that regularisation term is encrypted and remoted as well.

        Parameters
        ----------
        model_weights: weights passed to optimizer.loss_norm
        optimizer: object providing loss_norm(model_weights)
        n_iter_: current iteration, first half of the transfer suffix
        batch_index: current batch, second half of the transfer suffix
        cipher_operator: object providing encrypt(value)

        Returns
        ----------
        None
        """

        current_suffix = (n_iter_, batch_index)
        # self.forwards is not assigned in this class; presumably set by the
        # preceding gradient procedure -- TODO confirm
        self_wx_square = self.forwards.mapValues(lambda x: np.square(x)).reduce(reduce_add)
        LOGGER.debug("Host computed: (u^H)^2={}, type={}".format(self_wx_square, type(self_wx_square)))
        en_wx_square = cipher_operator.encrypt(self_wx_square)
        LOGGER.debug("Host encrypted: [[(u^H)^2]]={}, type={}".format(en_wx_square, type(en_wx_square)))
        self.remote_loss_intermediate(en_wx_square, suffix=current_suffix)
        loss_regular = optimizer.loss_norm(model_weights)

        # Nothing is sent when no regularisation term is produced.
        if loss_regular is not None:
            LOGGER.debug("Host computed: (λ/2)(Θ^H)^2={}, type={}".format(loss_regular, type(loss_regular)))
            en_loss_regular = cipher_operator.encrypt(loss_regular)
            LOGGER.debug("Host encrypted: [[(λ/2)(Θ^H)^2]]={}, type={}".format(en_loss_regular, type(en_loss_regular)))
            self.remote_loss_regular(en_loss_regular, suffix=current_suffix)


class Arbiter(hetero_linear_model_gradient.Arbiter, loss_sync.Arbiter):
    """Arbiter-side gradient and loss procedure of hetero linear regression."""

    def register_gradient_procedure(self, transfer_variables):
        """Register the gradient and loss transfer variables used by the arbiter."""
        tv = transfer_variables
        self._register_gradient_sync(
            tv.guest_gradient,
            tv.host_gradient,
            tv.guest_optim_gradient,
            tv.host_optim_gradient,
        )
        self._register_loss_sync(tv.loss)

    def compute_loss(self, cipher, n_iter_, batch_index):
        """
        Fetch the loss list synced for (n_iter_, batch_index) and decrypt it.

        Returns the result of cipher.decrypt_list on the received list.
        """
        encrypted_losses = self.sync_loss_info(suffix=(n_iter_, batch_index))
        return cipher.decrypt_list(encrypted_losses)

 

3. 数据预测:使用FATE纵向线性回归算法的模型训练结果进行预测比较简单,中间参数也未采用加密手段进行保护,其流程及交互消息如下图所示。

 

源码:

在源码中添加日志:

./federatedml/linear_model/linear_regression/hetero_linear_regression/hetero_linr_guest.py
HeteroLinRGuest->fit()
 L67:
  LOGGER.debug("Guest received from Arbiter: PK={}".format(self.cipher_operator.get_public_key()))
HeteroLinRGuest->predict()
 L148:
  LOGGER.debug("Guest computed: u^G={}, type={}".format(pred, type(pred)))
 L150:
  LOGGER.debug("Guest received from Host: u^H={}, type={}".format(host_preds, type(host_preds)))
 L157:
  LOGGER.debug("Guest computed: y={}, type={}".format(predict_result, type(predict_result)))

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

from federatedml.framework.hetero.procedure import convergence
from federatedml.framework.hetero.procedure import paillier_cipher, batch_generator
from federatedml.linear_model.linear_model_weight import LinearModelWeights
from federatedml.linear_model.linear_regression.hetero_linear_regression.hetero_linr_base import HeteroLinRBase
from federatedml.optim.gradient import hetero_linr_gradient_and_loss
from federatedml.secureprotol import EncryptModeCalculator
from federatedml.util import LOGGER
from federatedml.util import consts
from federatedml.util.io_check import assert_io_num_rows_equal


class HeteroLinRGuest(HeteroLinRBase):
    """Guest role of hetero linear regression: training loop and prediction."""

    def __init__(self):
        super().__init__()
        self.data_batch_count = []
        self.role = consts.GUEST
        # Guest-side helpers for each step of the procedure.
        self.cipher = paillier_cipher.Guest()
        self.batch_generator = batch_generator.Guest()
        self.gradient_loss_operator = hetero_linr_gradient_and_loss.Guest()
        self.converge_procedure = convergence.Guest()
        # Filled in fit() with one calculator per batch.
        self.encrypted_calculator = None

    @staticmethod
    def load_data(data_instance):
        """
        Hand the input back untouched.

        Parameters
        ----------
        data_instance: DTable of Instance, input data
        """
        return data_instance

    def fit(self, data_instances, validate_data=None):
        """
        Run the guest side of linR training.

        Parameters
        ----------
        data_instances: DTable of Instance, input data
        validate_data: optional data passed to init_validation_strategy
        """

        LOGGER.info("Enter hetero_linR_guest fit")
        self._abnormal_detection(data_instances)
        self.header = self.get_header(data_instances)

        self.validation_strategy = self.init_validation_strategy(data_instances, validate_data)

        self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

        LOGGER.info("Generate mini-batch from input data")
        LOGGER.debug("Guest received from Arbiter: PK={}".format(self.cipher_operator.get_public_key()))
        self.batch_generator.initialize_batch_generator(data_instances, self.batch_size)
        self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_nums)

        # One encrypted-mode calculator per mini-batch.
        batch_total = self.batch_generator.batch_nums
        self.encrypted_calculator = [
            EncryptModeCalculator(
                self.cipher_operator,
                self.encrypted_mode_calculator_param.mode,
                self.encrypted_mode_calculator_param.re_encrypted_rate,
            )
            for _ in range(batch_total)
        ]

        LOGGER.info("Start initialize model.")
        LOGGER.info("fit_intercept:{}".format(self.init_param_obj.fit_intercept))
        feature_shape = self.get_features_shape(data_instances)
        initial_weights = self.initializer.init_model(feature_shape, init_params=self.init_param_obj)
        self.model_weights = LinearModelWeights(initial_weights, fit_intercept=self.fit_intercept)

        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:{}".format(self.n_iter_))
            # A fresh generator is requested every iteration (the original
            # author notes it yields the same batches each time).
            batches = self.batch_generator.generate_batch_data()
            self.optimizer.set_iters(self.n_iter_)
            for batch_index, raw_batch in enumerate(batches):
                # Map the raw batch into the feature representation used for training.
                feature_batch = self.transform(raw_batch)

                # Gradient procedure; only the first of the three results is used here.
                gradient_result = self.gradient_loss_operator.compute_gradient_procedure(
                    feature_batch,
                    self.encrypted_calculator,
                    self.model_weights,
                    self.optimizer,
                    self.n_iter_,
                    batch_index
                )
                optim_guest_gradient, _, _ = gradient_result

                regular_term = self.optimizer.loss_norm(self.model_weights)
                self.gradient_loss_operator.compute_loss(data_instances, self.n_iter_, batch_index, regular_term)

                self.model_weights = self.optimizer.update_model(self.model_weights, optim_guest_gradient)

            self.is_converged = self.converge_procedure.sync_converge_info(suffix=(self.n_iter_,))
            LOGGER.info("iter: {},  is_converged: {}".format(self.n_iter_, self.is_converged))

            if self.validation_strategy:
                LOGGER.debug('LinR guest running validation')
                self.validation_strategy.validate(self, self.n_iter_)
                if self.validation_strategy.need_stop():
                    LOGGER.debug('early stopping triggered')
                    break

            self.n_iter_ += 1
            if self.is_converged:
                break

        # Restore the best model recorded during validation, if there is one.
        if self.validation_strategy:
            if self.validation_strategy.has_saved_best_model():
                self.load_model(self.validation_strategy.cur_best_model)
        self.set_summary(self.get_model_summary())

    @assert_io_num_rows_equal
    def predict(self, data_instances):
        """
        Predict with the trained linR model.

        The guest computes its own partial score, receives the hosts'
        partial scores and adds them key-wise before formatting the output.

        Parameters
        ----------
        data_instances: DTable of Instance, input data

        Returns
        ----------
        DTable
            result of predict_score_to_output for the summed scores
        """
        LOGGER.info("Start predict ...")
        self._abnormal_detection(data_instances)
        aligned_instances = self.align_data_header(data_instances, self.header)
        features = self.transform(aligned_instances)
        score = self.compute_wx(features, self.model_weights.coef_, self.model_weights.intercept_)
        LOGGER.debug("Guest computed: u^G={}, type={}".format(score, type(score)))
        host_scores = self.transfer_variable.host_partial_prediction.get(idx=-1)
        LOGGER.debug("Guest received from Host: u^H={}, type={}".format(host_scores, type(host_scores)))
        LOGGER.info("Get prediction from Host")

        for host_score in host_scores:
            score = score.join(host_score, lambda guest_part, host_part: guest_part + host_part)

        predict_result = self.predict_score_to_output(
            data_instances=aligned_instances,
            predict_score=score,
            classes=None,
        )
        LOGGER.debug("Guest computed: y={}, type={}".format(predict_result, type(predict_result)))
        return predict_result

./federatedml/linear_model/linear_regression/hetero_linear_regression/hetero_linr_host.py
HeteroLinRHost->fit()
 L58:
  LOGGER.debug("Host receives from Arbiter: PK={}".format(self.cipher_operator.get_public_key()))
HeteroLinRHost->predict()
 L131:
  LOGGER.debug("Host computed: u^H={}, type={}".format(pred_host, type(pred_host)))
 L133:
  LOGGER.debug("Host sent to Guest: u^H={}, type={}".format(pred_host, type(pred_host)))

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

from federatedml.framework.hetero.procedure import convergence
from federatedml.framework.hetero.procedure import paillier_cipher, batch_generator
from federatedml.linear_model.linear_model_weight import LinearModelWeights
from federatedml.linear_model.linear_regression.hetero_linear_regression.hetero_linr_base import HeteroLinRBase
from federatedml.optim.gradient import hetero_linr_gradient_and_loss
from federatedml.secureprotol import EncryptModeCalculator
from federatedml.util import LOGGER
from federatedml.util import consts


class HeteroLinRHost(HeteroLinRBase):
    """Host role of hetero linear regression: training loop and partial prediction."""

    def __init__(self):
        super(HeteroLinRHost, self).__init__()
        self.batch_num = None
        self.batch_index_list = []
        self.role = consts.HOST

        # Host-side helpers for each step of the procedure.
        self.cipher = paillier_cipher.Host()
        self.batch_generator = batch_generator.Host()
        self.gradient_loss_operator = hetero_linr_gradient_and_loss.Host()
        self.converge_procedure = convergence.Host()
        # Filled in fit() with one calculator per batch.
        self.encrypted_calculator = None

    def fit(self, data_instances, validate_data=None):
        """
        Train linear regression model of role host

        Parameters
        ----------
        data_instances: DTable of Instance, input data
        validate_data: optional data passed to init_validation_strategy
        """

        LOGGER.info("Enter hetero_linR host")
        self._abnormal_detection(data_instances)

        self.validation_strategy = self.init_validation_strategy(data_instances, validate_data)

        self.header = self.get_header(data_instances)
        self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

        self.batch_generator.initialize_batch_generator(data_instances)
        self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_nums)
        LOGGER.debug("Host receives from Arbiter: PK={}".format(self.cipher_operator.get_public_key()))
        # One encrypted-mode calculator per mini-batch.
        self.encrypted_calculator = [EncryptModeCalculator(self.cipher_operator,
                                                           self.encrypted_mode_calculator_param.mode,
                                                           self.encrypted_mode_calculator_param.re_encrypted_rate) for _
                                     in range(self.batch_generator.batch_nums)]

        LOGGER.info("Start initialize model.")
        model_shape = self.get_features_shape(data_instances)
        # The host always initialises its weights with fit_intercept turned
        # off on the init-param object.
        if self.init_param_obj.fit_intercept:
            self.init_param_obj.fit_intercept = False

        w = self.initializer.init_model(model_shape, init_params=self.init_param_obj)
        self.model_weights = LinearModelWeights(w, fit_intercept=self.fit_intercept)

        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:" + str(self.n_iter_))
            self.optimizer.set_iters(self.n_iter_)
            batch_data_generator = self.batch_generator.generate_batch_data()
            batch_index = 0
            for batch_data in batch_data_generator:
                batch_feat_inst = self.transform(batch_data)
                # Only the first of the two returned values is used here.
                optim_host_gradient, _ = self.gradient_loss_operator.compute_gradient_procedure(
                    batch_feat_inst,
                    self.encrypted_calculator,
                    self.model_weights,
                    self.optimizer,
                    self.n_iter_,
                    batch_index)

                self.gradient_loss_operator.compute_loss(self.model_weights, self.optimizer, self.n_iter_, batch_index,
                                                         self.cipher_operator)

                self.model_weights = self.optimizer.update_model(self.model_weights, optim_host_gradient)
                batch_index += 1

            self.is_converged = self.converge_procedure.sync_converge_info(suffix=(self.n_iter_,))

            LOGGER.info("Get is_converged flag from arbiter:{}".format(self.is_converged))

            if self.validation_strategy:
                LOGGER.debug('LinR host running validation')
                self.validation_strategy.validate(self, self.n_iter_)
                if self.validation_strategy.need_stop():
                    LOGGER.debug('early stopping triggered')
                    break

            self.n_iter_ += 1
            # Logged after the increment, so the number shown is the count of
            # finished iterations.
            LOGGER.info("iter: {}, is_converged: {}".format(self.n_iter_, self.is_converged))
            if self.is_converged:
                break
        if not self.is_converged:
            LOGGER.info("Reach max iter {}, train model finish!".format(self.max_iter))
        # Restore the best model recorded during validation, if there is one.
        if self.validation_strategy and self.validation_strategy.has_saved_best_model():
            self.load_model(self.validation_strategy.cur_best_model)
        self.set_summary(self.get_model_summary())

    def predict(self, data_instances):
        """
        Compute the host's partial prediction and send it to the guest.

        Nothing is returned; the partial score table is remoted to the guest
        (idx=0) through host_partial_prediction.

        Parameters
        ----------
        data_instances:DTable of Instance, input data
        """
        self.transfer_variable.host_partial_prediction.disable_auto_clean()

        LOGGER.info("Start predict ...")

        self._abnormal_detection(data_instances)
        data_instances = self.align_data_header(data_instances, self.header)
        data_features = self.transform(data_instances)

        pred_host = self.compute_wx(data_features, self.model_weights.coef_, self.model_weights.intercept_)
        # u^H is produced here on the host, so the message is labelled "Host".
        LOGGER.debug("Host computed: u^H={}, type={}".format(pred_host, type(pred_host)))
        self.transfer_variable.host_partial_prediction.remote(pred_host, role=consts.GUEST, idx=0)
        LOGGER.debug("Host sent to Guest: u^H={}, type={}".format(pred_host, type(pred_host)))
        LOGGER.info("Remote partial prediction to Guest")

 

 

0条评论
0 / 1000
AndyXiong
4文章数
0粉丝数
AndyXiong
4 文章 | 0 粉丝
AndyXiong
4文章数
0粉丝数
AndyXiong
4 文章 | 0 粉丝
原创

多方隐私计算纵向线性回归算法详细过程分析

2023-03-28 10:24:43
30
0

多方隐私计算纵向线性回归算法流程描述:核心流程包括样本对齐、模型训练、数据预测

1. 样本对齐:样本对齐使用基于RSA算法的安全求交,其流程及交互消息如下图所示。

   

源码:

Host侧:源码中添加日志

RsaIntersectionHost->run()
 L175:
  LOGGER.debug("Host generated RSA key pair: PK_e={}, PK_n={}, SK_d={}".format(self.e, self.n, self.d))
 L181:
  LOGGER.debug("Host sent to Guest: PK={}".format(public_key))
 L184:
  LOGGER.debug("Host computed: Z_H={}, type={}".format(host_ids_process_pair, type(host_ids_process_pair)))
 L191:
  LOGGER.debug("Host sent to Guest: Z_H={}, type={}".format(host_ids_process, type(host_ids_process)))
 L196:
  LOGGER.debug("Host received from Guest: Y_G={}, type={}".format(guest_ids, type(guest_ids)))
 L201:
  LOGGER.debug("Host computed: Z_G={}, type={}".format(guest_ids_process, type(guest_ids_process)))
 L205:
  LOGGER.debug("Host sent to Guest: Z_G={}, type={}".format(guest_ids_process, type(guest_ids_process)))
 L212:
  LOGGER.debug("Host received from Guest: S={}, type={}".format(encrypt_intersect_ids, type(encrypt_intersect_ids)))
 L215:
  LOGGER.debug("Host computed: I_id={}, type={}".format(intersect_ids, type(intersect_ids)))

Host侧:添加后的源码

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

import hashlib

from fate_arch.session import computing_session as session
from federatedml.secureprotol import gmpy_math
from federatedml.secureprotol.encrypt import RsaEncrypt
#from federatedml.statistic.intersect.rsa_cache import cache_utils
from federatedml.statistic.intersect import RawIntersect
from federatedml.statistic.intersect import RsaIntersect
from federatedml.util import consts
from federatedml.util import LOGGER
from federatedml.transfer_variable.transfer_class.rsa_intersect_transfer_variable import RsaIntersectTransferVariable


class RsaIntersectionHost(RsaIntersect):
    """
    Host side of the RSA-based intersection.

    NOTE(review): the cache code paths reference `cache_utils`, whose import
    is commented out at the top of this file; with use_cache enabled they
    will raise NameError unless the name is provided elsewhere.
    """

    def __init__(self, intersect_params):
        super().__init__(intersect_params)
        self.transfer_variable = RsaIntersectTransferVariable()

        # RSA key material, filled in by run()
        self.e = None
        self.d = None
        self.n = None

        # parameter for intersection cache
        self.is_version_match = False
        self.has_cache_version = True

    def cal_host_ids_process_pair(self, data_instances):
        """
        Map every (k, v) to (hash(powmod(int(hash(k), 16), d, n)), k).

        The inherited hash is presumably a hex digest, since its result is
        parsed with base 16 -- TODO confirm.
        """
        return data_instances.map(
            lambda k, v: (
                RsaIntersectionHost.hash(gmpy_math.powmod(int(RsaIntersectionHost.hash(k), 16), self.d, self.n)), k)
        )

    def generate_rsa_key(self, rsa_bit=1024):
        """Generate a new RSA key of rsa_bit bits and return RsaEncrypt.get_key_pair()."""
        encrypt_operator = RsaEncrypt()
        encrypt_operator.generate_key(rsa_bit)
        return encrypt_operator.get_key_pair()

    def get_rsa_key(self):
        """
        Return (e, d, n), read from the cache when enabled and available.

        When the cache is enabled but holds no key, has_cache_version is set
        to False and a new key is generated.
        """
        if self.intersect_cache_param.use_cache:
            LOGGER.info("Using intersection cache scheme, start to getting rsa key from cache.")
            rsa_key = cache_utils.get_rsa_of_current_version(host_party_id=self.host_party_id,
                                                             id_type=self.intersect_cache_param.id_type,
                                                             encrypt_type=self.intersect_cache_param.encrypt_type,
                                                             tag='Za')
            if rsa_key is not None:
                e = int(rsa_key.get('rsa_e'))
                d = int(rsa_key.get('rsa_d'))
                n = int(rsa_key.get('rsa_n'))
            else:
                self.has_cache_version = False
                LOGGER.info("Use cache but can not find any version in cache, set has_cache_version to false")
                LOGGER.info("Start to generate rsa key")
                e, d, n = self.generate_rsa_key()
        else:
            LOGGER.info("Not use cache, generate rsa keys.")
            e, d, n = self.generate_rsa_key()
        return e, d, n

    def store_cache(self, host_id, rsa_key: dict, assign_version=None, assign_namespace=None):
        """
        Store the processed host ids and the RSA key in the cache.

        Returns
        ----------
        tuple
            (version, namespace) as reported by cache_utils.store_cache
        """
        store_cache_ret = cache_utils.store_cache(dtable=host_id,
                                                  guest_party_id=None,
                                                  host_party_id=self.host_party_id,
                                                  version=assign_version,
                                                  id_type=self.intersect_cache_param.id_type,
                                                  encrypt_type=self.intersect_cache_param.encrypt_type,
                                                  tag='Za',
                                                  namespace=assign_namespace)
        LOGGER.info("Finish store host_ids_process to cache")

        version = store_cache_ret.get('table_name')
        namespace = store_cache_ret.get('namespace')
        cache_utils.store_rsa(host_party_id=self.host_party_id,
                              id_type=self.intersect_cache_param.id_type,
                              encrypt_type=self.intersect_cache_param.encrypt_type,
                              tag='Za',
                              namespace=namespace,
                              version=version,
                              rsa=rsa_key
                              )
        LOGGER.info("Finish store rsa key to cache")

        return version, namespace

    def host_ids_process(self, data_instances):
        """
        Produce the (processed id, raw id) table, using the cache if enabled.

        Returns None when the cache is used, the cached version matches the
        guest's and sync_intersect_ids is false, because the table is not
        needed in that case.
        """
        # (host_id_process, 1)
        if self.intersect_cache_param.use_cache:
            LOGGER.info("Use intersect cache.")
            if self.has_cache_version:
                current_version = cache_utils.host_get_current_verison(host_party_id=self.host_party_id,
                                                                       id_type=self.intersect_cache_param.id_type,
                                                                       encrypt_type=self.intersect_cache_param.encrypt_type,
                                                                       tag='Za')
                # NOTE(review): current_version is dereferenced here, before
                # the `is not None` test in the condition below.
                version = current_version.get('table_name')
                namespace = current_version.get('namespace')
                guest_current_version = self.transfer_variable.cache_version_info.get(0)
                LOGGER.info("current_version:{}".format(current_version))
                LOGGER.info("guest_current_version:{}".format(guest_current_version))

                if guest_current_version.get('table_name') == version \
                        and guest_current_version.get('namespace') == namespace and \
                        current_version is not None:
                    self.is_version_match = True
                else:
                    self.is_version_match = False

                version_match_info = {'version_match': self.is_version_match,
                                      'version': version,
                                      'namespace': namespace}
                self.transfer_variable.cache_version_match_info.remote(version_match_info,
                                                                       role=consts.GUEST,
                                                                       idx=0)

                host_ids_process_pair = None
                if not self.is_version_match or self.sync_intersect_ids:
                    # if self.sync_intersect_ids is true, host will get the encrypted intersect id from guest,
                    # which need the Za to decrypt them
                    LOGGER.info("read Za from cache")
                    host_ids_process_pair = session.table(name=version,
                                                          namespace=namespace,
                                                          create_if_missing=True,
                                                          error_if_exist=False)
                    if host_ids_process_pair.count() == 0:
                        host_ids_process_pair = self.cal_host_ids_process_pair(data_instances)
                        rsa_key = {'rsa_e': self.e, 'rsa_d': self.d, 'rsa_n': self.n}
                        self.store_cache(host_ids_process_pair, rsa_key=rsa_key)
            else:
                self.is_version_match = False
                LOGGER.info("is version_match:{}".format(self.is_version_match))
                namespace = cache_utils.gen_cache_namespace(id_type=self.intersect_cache_param.id_type,
                                                            encrypt_type=self.intersect_cache_param.encrypt_type,
                                                            tag='Za',
                                                            host_party_id=self.host_party_id)
                version = cache_utils.gen_cache_version(namespace=namespace,
                                                        create=True)
                version_match_info = {'version_match': self.is_version_match,
                                      'version': version,
                                      'namespace': namespace}
                self.transfer_variable.cache_version_match_info.remote(version_match_info,
                                                                       role=consts.GUEST,
                                                                       idx=0)

                host_ids_process_pair = self.cal_host_ids_process_pair(data_instances)
                rsa_key = {'rsa_e': self.e, 'rsa_d': self.d, 'rsa_n': self.n}
                self.store_cache(host_ids_process_pair, rsa_key=rsa_key, assign_version=version, assign_namespace=namespace)

            LOGGER.info("remote version match info to guest")
        else:
            LOGGER.info("Not using cache, calculate Za using raw id")
            host_ids_process_pair = self.cal_host_ids_process_pair(data_instances)

        return host_ids_process_pair

    def run(self, data_instances):
        """
        Execute the host side of the RSA intersection.

        Parameters
        ----------
        data_instances: table of the host's instances, keyed by id

        Returns
        ----------
        The intersecting ids when sync_intersect_ids is true, otherwise None.
        """
        LOGGER.info("Start rsa intersection")
        self.e, self.d, self.n = self.get_rsa_key()
        LOGGER.info("Get rsa key!")
        public_key = {"e": self.e, "n": self.n}

        # NOTE(security): this writes the RSA private exponent d to the debug
        # log; acceptable only for tracing/teaching runs.
        LOGGER.debug("Host generated RSA key pair: PK_e={}, PK_n={}, SK_d={}".format(self.e, self.n, self.d))
        self.transfer_variable.rsa_pubkey.remote(public_key,
                                                 role=consts.GUEST,
                                                 idx=0)
        LOGGER.info("Remote public key to Guest.")
        host_ids_process_pair = self.host_ids_process(data_instances)
        LOGGER.debug("Host sent to Guest: PK={}".format(public_key))

        # Z_H is sent unless the guest already holds a matching cached version.
        guest_has_cached_ids = self.intersect_cache_param.use_cache and self.is_version_match
        if not guest_has_cached_ids:
            host_ids_process = host_ids_process_pair.mapValues(lambda v: 1)
            LOGGER.debug("Host computed: Z_H={}, type={}".format(host_ids_process_pair, type(host_ids_process_pair)))
            self.transfer_variable.intersect_host_ids_process.remote(host_ids_process,
                                                                     role=consts.GUEST,
                                                                     idx=0)
            LOGGER.info("Remote host_ids_process to Guest.")
            # Kept inside the branch: host_ids_process only exists when Z_H
            # was actually sent.
            LOGGER.debug("Host sent to Guest: Z_H={}, type={}".format(host_ids_process, type(host_ids_process)))

        # Recv guest ids
        guest_ids = self.transfer_variable.intersect_guest_ids.get(idx=0)
        LOGGER.info("Get guest_ids from guest")

        # Process guest ids and return to guest
        LOGGER.debug("Host received from Guest: Y_G={}, type={}".format(guest_ids, type(guest_ids)))
        guest_ids_process = guest_ids.map(lambda k, v: (k, gmpy_math.powmod(int(k), self.d, self.n)))
        self.transfer_variable.intersect_guest_ids_process.remote(guest_ids_process,
                                                                  role=consts.GUEST,
                                                                  idx=0)
        LOGGER.debug("Host computed: Z_G={}, type={}".format(guest_ids_process, type(guest_ids_process)))
        LOGGER.info("Remote guest_ids_process to Guest.")

        # recv intersect ids
        LOGGER.debug("Host sent to Guest: Z_G={}, type={}".format(guest_ids_process, type(guest_ids_process)))
        intersect_ids = None
        if self.sync_intersect_ids:
            encrypt_intersect_ids = self.transfer_variable.intersect_ids.get(idx=0)
            # Join on the processed id to recover the host's raw id.
            intersect_ids_pair = encrypt_intersect_ids.join(host_ids_process_pair, lambda e, h: h)
            intersect_ids = intersect_ids_pair.map(lambda k, v: (v, "id"))
            LOGGER.info("Get intersect ids from Guest")
            LOGGER.debug("Host received from Guest: S={}, type={}".format(encrypt_intersect_ids, type(encrypt_intersect_ids)))
            if not self.only_output_key:
                intersect_ids = self._get_value_from_data(intersect_ids, data_instances)
        LOGGER.debug("Host computed: I_id={}, type={}".format(intersect_ids, type(intersect_ids)))
        return intersect_ids


class RawIntersectionHost(RawIntersect):
    """Host-side runner of the raw intersection.

    Depending on which party is configured as the join role, the host either
    sends its IDs to the other side or performs the join itself.
    """

    def __init__(self, intersect_params):
        """Read the configured join role and mark this instance as the host."""
        super().__init__(intersect_params)
        self.join_role = intersect_params.join_role
        self.role = consts.HOST

    def run(self, data_instances):
        """Run the raw intersection for ``data_instances``.

        :param data_instances: table of local samples to intersect.
        :return: whatever ``intersect_send_id`` / ``intersect_join_id`` returns.
        :raises ValueError: when ``join_role`` is neither guest nor host.
        """
        LOGGER.info("Start raw intersection")

        join_role = self.join_role
        if join_role == consts.GUEST:
            # The guest joins, so the host only ships its IDs.
            return self.intersect_send_id(data_instances)
        if join_role == consts.HOST:
            # The host itself performs the join.
            return self.intersect_join_id(data_instances)

        raise ValueError("Unknown intersect join role, please check the configure of host")

Guest侧:源码中添加日志

RsaIntersectionGuest->run()
 L142:
  LOGGER.debug("Guest received from Host: RSA public key PK_e={}, PK_n={}".format(self.e, self.n))
 L149:
  LOGGER.debug("Guest computed: Y_G={}, type={}".format(guest_id_process_list, type(guest_id_process_list)))
 L157:
  LOGGER.debug("Guest sent to Host: Y_G={}, type={}".format(mask_guest_id, type(mask_guest_id)))
 L161:
  LOGGER.debug("Guest received from Host: Z_H={}, type={}".format(host_ids_process_list, type(host_ids_process_list)))
 L167:
  LOGGER.debug("Guest received from Host: Z_G={}, type={}".format(recv_guest_ids_process, type(recv_guest_ids_process)))
 L173:
  LOGGER.debug("Guest computed: D_G={}, type={}".format(guest_ids_process_final, type(guest_ids_process_final)))
 L183:
  LOGGER.debug("Guest computed: S={}, type={}".format(encrypt_intersect_ids, type(encrypt_intersect_ids)))
 L206:
  LOGGER.debug("Guest sent to Host: S={}, type={}".format(remote_intersect_id, type(remote_intersect_id)))
 L209:
  LOGGER.debug("Guest computed: I_id={}, type={}".format(intersect_ids, type(intersect_ids)))

Guest侧:添加后的源码

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from collections import Iterable

import gmpy2
import random

from fate_arch.session import computing_session as session
from federatedml.secureprotol import gmpy_math
from federatedml.statistic.intersect import RawIntersect
from federatedml.statistic.intersect import RsaIntersect
#from federatedml.statistic.intersect.rsa_cache import cache_utils
from federatedml.util import consts
from federatedml.util import LOGGER
from federatedml.transfer_variable.transfer_class.rsa_intersect_transfer_variable import RsaIntersectTransferVariable


class RsaIntersectionGuest(RsaIntersect):
    """Guest-side driver of the RSA based ID intersection.

    For every host the guest blinds its hashed sample IDs with that host's RSA
    public key (``r^e * hash(sid) mod n``), sends them to the host, divides the
    random factor ``r`` out of the values the host returns, hashes the result
    and joins it with the processed ID table received from the same host.
    """

    def __init__(self, intersect_params):
        super().__init__(intersect_params)

        self.random_bit = intersect_params.random_bit

        # One public exponent / modulus per host; populated by run().
        self.e = None
        self.n = None
        self.transfer_variable = RsaIntersectTransferVariable()

        # parameter for intersection cache
        self.intersect_cache_param = intersect_params.intersect_cache_param

    def map_raw_id_to_encrypt_id(self, raw_id_data, encrypt_id_data):
        """Map raw intersect IDs back to the encrypted IDs shared with one host.

        :param raw_id_data: table keyed by raw sample ID.
        :param encrypt_id_data: table of ``(encrypted id, raw sample id)``.
        :return: table of ``(encrypted id, "id")`` for the raw IDs present in
            ``raw_id_data``.
        """
        encrypt_id_data_exchange_kv = encrypt_id_data.map(lambda k, v: (v, k))
        encrypt_raw_id = raw_id_data.join(encrypt_id_data_exchange_kv, lambda r, e: e)
        encrypt_common_id = encrypt_raw_id.map(lambda k, v: (v, "id"))

        return encrypt_common_id

    def get_cache_version_match_info(self):
        """Exchange cache version information with every host.

        :return: the value received through ``cache_version_match_info`` when
            the cache is enabled, otherwise ``None``.
        """
        if self.intersect_cache_param.use_cache:
            LOGGER.info("Use cache is true")
            # NOTE(review): `cache_utils` comes from the rsa_cache import that is
            # commented out at the top of this file; with use_cache enabled this
            # branch raises NameError until that import is restored.
            # check local cache version for each host
            for i, host_party_id in enumerate(self.host_party_id_list):
                current_version = cache_utils.guest_get_current_version(host_party_id=host_party_id,
                                                                        guest_party_id=None,
                                                                        id_type=self.intersect_cache_param.id_type,
                                                                        encrypt_type=self.intersect_cache_param.encrypt_type,
                                                                        tag='Za'
                                                                        )
                LOGGER.info("host_id:{}, current_version:{}".format(host_party_id, current_version))
                if current_version is None:
                    current_version = {"table_name": None, "namespace": None}

                self.transfer_variable.cache_version_info.remote(current_version,
                                                                 role=consts.HOST,
                                                                 idx=i)
                LOGGER.info("Remote current version to host:{}".format(host_party_id))

            cache_version_match_info = self.transfer_variable.cache_version_match_info.get(idx=-1)
            LOGGER.info("Get cache version match info:{}".format(cache_version_match_info))
        else:
            cache_version_match_info = None
            LOGGER.info("Not using cache, cache_version_match_info is None")

        return cache_version_match_info

    def get_host_id_process(self, cache_version_match_info):
        """Collect the processed ID table of every host.

        With the cache enabled, tables whose version matches are read from the
        local cache; the remaining ones are received from the hosts and then
        stored in the cache. Without the cache all tables are received.

        :param cache_version_match_info: per-host version info as returned by
            :meth:`get_cache_version_match_info` (``None`` when cache is off).
        :return: list with one processed ID table per host.
        """
        if self.intersect_cache_param.use_cache:
            host_ids_process_list = [None for _ in self.host_party_id_list]

            # NOTE(review): `Iterable` is imported from `collections` at the top
            # of the file, which no longer works on Python >= 3.10; it should be
            # imported from `collections.abc`.
            if isinstance(cache_version_match_info, Iterable):
                for i, version_info in enumerate(cache_version_match_info):
                    if version_info.get('version_match'):
                        host_ids_process_list[i] = session.table(name=version_info.get('version'),
                                                                 namespace=version_info.get('namespace'),
                                                                 create_if_missing=True,
                                                                 error_if_exist=False)
                        LOGGER.info("Read host {} 's host_ids_process from cache".format(self.host_party_id_list[i]))
            else:
                LOGGER.info("cache_version_match_info is not iterable, not use cache_version_match_info")

            # check which host_id_process is not receive yet
            host_ids_process_rev_idx_list = [i for i, e in enumerate(host_ids_process_list) if e is None]

            if len(host_ids_process_rev_idx_list) > 0:
                # Recv host_ids_process
                # table(host_id_process, 1)
                host_ids_process = []
                for rev_idx in host_ids_process_rev_idx_list:
                    host_ids_process.append(self.transfer_variable.intersect_host_ids_process.get(idx=rev_idx))
                    LOGGER.info("Get host_ids_process from host {}".format(self.host_party_id_list[rev_idx]))

                for i, host_idx in enumerate(host_ids_process_rev_idx_list):
                    host_ids_process_list[host_idx] = host_ids_process[i]

                    version = cache_version_match_info[host_idx].get('version')
                    namespace = cache_version_match_info[host_idx].get('namespace')

                    # `i` only indexes the freshly received sub-list; the party id
                    # must be looked up with the host's own index, otherwise the
                    # table is cached under the wrong host whenever some hosts
                    # were served from the cache.
                    cache_utils.store_cache(dtable=host_ids_process[i],
                                            guest_party_id=self.guest_party_id,
                                            host_party_id=self.host_party_id_list[host_idx],
                                            version=version,
                                            id_type=self.intersect_cache_param.id_type,
                                            encrypt_type=self.intersect_cache_param.encrypt_type,
                                            tag=consts.INTERSECT_CACHE_TAG,
                                            namespace=namespace
                                            )
                    LOGGER.info("Store host {}'s host_ids_process to cache.".format(self.host_party_id_list[host_idx]))
        else:
            host_ids_process_list = self.transfer_variable.intersect_host_ids_process.get(idx=-1)
            LOGGER.info("Not using cache, get host_ids_process from all host")

        return host_ids_process_list

    @staticmethod
    def guest_id_process(sid, random_bit, rsa_e, rsa_n):
        """Blind one sample ID with a fresh random factor.

        :param sid: raw sample ID.
        :param random_bit: bit length of the random factor ``r``.
        :param rsa_e: RSA public exponent of the host.
        :param rsa_n: RSA modulus of the host.
        :return: ``(r^e * hash(sid) mod n, (sid, r))``.
        """
        r = random.SystemRandom().getrandbits(random_bit)
        re_hash = gmpy_math.powmod(r, rsa_e, rsa_n) * int(RsaIntersectionGuest.hash(sid), 16) % rsa_n
        return re_hash, (sid, r)

    def run(self, data_instances):
        """Execute the guest side of the RSA intersection.

        :param data_instances: table keyed by raw sample ID.
        :return: table of intersecting IDs; joined with the original values of
            ``data_instances`` unless ``only_output_key`` is set.
        """
        LOGGER.info("Start rsa intersection")

        public_keys = self.transfer_variable.rsa_pubkey.get(-1)
        LOGGER.info("Get RSA public_key:{} from Host".format(public_keys))
        self.e = [int(public_key["e"]) for public_key in public_keys]
        self.n = [int(public_key["n"]) for public_key in public_keys]
        LOGGER.debug("Guest received from Host: RSA public key PK_e={}, PK_n={}".format(self.e, self.n))

        cache_version_match_info = self.get_cache_version_match_info()

        # table (r^e % n * hash(sid), sid, r)
        # The per-host key is bound through lambda defaults so that every mapper
        # keeps its own (e, n) even if the backend evaluates the function after
        # the comprehension has moved on to the next host.
        random_bit = self.random_bit
        blind = self.guest_id_process
        guest_id_process_list = [
            data_instances.map(
                lambda k, v, rsa_e=rsa_e, rsa_n=rsa_n: blind(k, random_bit=random_bit, rsa_e=rsa_e, rsa_n=rsa_n))
            for rsa_e, rsa_n in zip(self.e, self.n)
        ]
        LOGGER.debug("Guest computed: Y_G={}, type={}".format(guest_id_process_list, type(guest_id_process_list)))

        # table(r^e % n *hash(sid), 1)
        for i, guest_id in enumerate(guest_id_process_list):
            mask_guest_id = guest_id.mapValues(lambda v: 1)
            self.transfer_variable.intersect_guest_ids.remote(mask_guest_id,
                                                              role=consts.HOST,
                                                              idx=i)
            # Logged per host, right after the send, so the variable is always bound.
            LOGGER.debug("Guest sent to Host: Y_G={}, type={}".format(mask_guest_id, type(mask_guest_id)))
            LOGGER.info("Remote guest_id to Host {}".format(i))

        host_ids_process_list = self.get_host_id_process(cache_version_match_info)
        LOGGER.info("Get host_ids_process")
        LOGGER.debug("Guest received from Host: Z_H={}, type={}".format(host_ids_process_list, type(host_ids_process_list)))

        # Recv process guest ids
        # table(r^e % n *hash(sid), guest_id_process)
        recv_guest_ids_process = self.transfer_variable.intersect_guest_ids_process.get(idx=-1)
        LOGGER.info("Get guest_ids_process from Host")
        LOGGER.debug("Guest received from Host: Z_G={}, type={}".format(recv_guest_ids_process, type(recv_guest_ids_process)))

        # table(r^e % n *hash(sid), sid, hash(guest_ids_process/r))
        guest_ids_process_final = [
            v.join(recv_guest_ids_process[i],
                   lambda g, r, rsa_n=self.n[i]: (
                       g[0], RsaIntersectionGuest.hash(gmpy2.divm(int(r), int(g[1]), rsa_n))))
            for i, v in enumerate(guest_id_process_list)
        ]
        LOGGER.debug("Guest computed: D_G={}, type={}".format(guest_ids_process_final, type(guest_ids_process_final)))

        # table(hash(guest_ids_process/r), sid))
        sid_guest_ids_process_final = [g.map(lambda k, v: (v[1], v[0])) for g in guest_ids_process_final]

        # intersect table(hash(guest_ids_process/r), sid)
        encrypt_intersect_ids = [v.join(host_ids_process_list[i], lambda sid, h: sid) for i, v in
                                 enumerate(sid_guest_ids_process_final)]
        LOGGER.debug("Guest computed: S={}, type={}".format(encrypt_intersect_ids, type(encrypt_intersect_ids)))

        if len(self.host_party_id_list) > 1:
            raw_intersect_ids = [e.map(lambda k, v: (v, 1)) for e in encrypt_intersect_ids]
            intersect_ids = self.get_common_intersection(raw_intersect_ids)

            # send intersect id
            if self.sync_intersect_ids:
                for i, host_party_id in enumerate(self.host_party_id_list):
                    remote_intersect_id = self.map_raw_id_to_encrypt_id(intersect_ids, encrypt_intersect_ids[i])
                    self.transfer_variable.intersect_ids.remote(remote_intersect_id,
                                                                role=consts.HOST,
                                                                idx=i)
                    LOGGER.debug("Guest sent to Host: S={}, type={}".format(remote_intersect_id, type(remote_intersect_id)))
                    LOGGER.info("Remote intersect ids to Host {}!".format(host_party_id))
            else:
                LOGGER.info("Not send intersect ids to Host!")
        else:
            intersect_ids = encrypt_intersect_ids[0]
            if self.sync_intersect_ids:
                remote_intersect_id = intersect_ids.mapValues(lambda v: 1)
                self.transfer_variable.intersect_ids.remote(remote_intersect_id,
                                                            role=consts.HOST,
                                                            idx=0)
                # Only logged when something was actually sent; logging it
                # unconditionally raised UnboundLocalError with sync disabled.
                LOGGER.debug("Guest sent to Host: S={}, type={}".format(remote_intersect_id, type(remote_intersect_id)))

            intersect_ids = intersect_ids.map(lambda k, v: (v, 1))
        LOGGER.info("Finish intersect_ids computing")

        LOGGER.debug("Guest computed: I_id={}, type={}".format(intersect_ids, type(intersect_ids)))
        if not self.only_output_key:
            intersect_ids = self._get_value_from_data(intersect_ids, data_instances)

        return intersect_ids


class RawIntersectionGuest(RawIntersect):
    """Guest-side runner of the raw intersection.

    The guest sends its IDs when the host is the join role and performs the
    join itself when the guest is the join role.
    """

    def __init__(self, intersect_params):
        """Mark this instance as the guest and read the configured join role."""
        super().__init__(intersect_params)
        self.role = consts.GUEST
        self.join_role = intersect_params.join_role

    def run(self, data_instances):
        """Run the raw intersection for ``data_instances``.

        :param data_instances: table of local samples to intersect.
        :return: whatever ``intersect_send_id`` / ``intersect_join_id`` returns.
        :raises ValueError: when ``join_role`` is neither host nor guest.
        """
        LOGGER.info("Start raw intersection")

        join_role = self.join_role
        if join_role == consts.HOST:
            # The host joins, so the guest only ships its IDs.
            return self.intersect_send_id(data_instances)
        if join_role == consts.GUEST:
            # The guest itself performs the join.
            return self.intersect_join_id(data_instances)

        raise ValueError("Unknown intersect join role, please check the configure of guest")

 

2. 模型训练:纵向线性回归算法模型训练中间参数的计算采用Paillier同态加密算法进行保护,并依赖中间方Arbiter进行收敛判断,其流程及交互消息如下图所示。

源码:

./federatedml/linear_model/base_linear_model_arbiter.py
HeteroBaseArbiter->fit()
 L79:
  LOGGER.info("iter:" + str(self.n_iter_))

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

from federatedml.framework.hetero.procedure import convergence
from federatedml.framework.hetero.procedure import paillier_cipher, batch_generator
from federatedml.linear_model.linear_model_base import BaseLinearModel
from federatedml.util import LOGGER
from federatedml.util import consts
from federatedml.util import fate_operator
from federatedml.util.validation_strategy import ValidationStrategy


class HeteroBaseArbiter(BaseLinearModel):
    """Arbiter-side base class of the hetero linear models.

    During ``fit`` the arbiter generates the Paillier key pair, drives the
    per-batch gradient and loss procedures, decides convergence and sends the
    convergence flag to the other parties.
    """

    def __init__(self):
        # NOTE(review): super(BaseLinearModel, self) starts the MRO lookup *after*
        # BaseLinearModel, i.e. BaseLinearModel.__init__ itself is skipped —
        # confirm this is intentional.
        super(BaseLinearModel, self).__init__()
        self.role = consts.ARBITER

        # attribute
        self.pre_loss = None
        self.loss_history = []
        self.cipher = paillier_cipher.Arbiter()
        self.batch_generator = batch_generator.Arbiter()
        self.gradient_loss_operator = None
        self.converge_procedure = convergence.Arbiter()
        self.best_iteration = -1

    def perform_subtasks(self, **training_info):
        """
        performs any tasks that the arbiter is responsible for.

        This 'perform_subtasks' function serves as a handler on conducting any task that the arbiter is responsible
        for. For example, for the 'perform_subtasks' function of 'HeteroDNNLRArbiter' class located in
        'hetero_dnn_lr_arbiter.py', it performs some works related to updating/training local neural networks of guest
        or host.

        For this particular class, the 'perform_subtasks' function will do nothing. In other words, no subtask is
        performed by this arbiter.

        :param training_info: a dictionary holding training information
        """
        pass

    def init_validation_strategy(self, train_data=None, validate_data=None):
        """Build the ValidationStrategy for this role from the model's settings.

        ``train_data`` and ``validate_data`` are accepted but not used here.
        """
        validation_strategy = ValidationStrategy(self.role, self.mode, self.validation_freqs,
                                                 self.early_stopping_rounds,
                                                 self.use_first_metric_only)
        return validation_strategy

    def fit(self, data_instances=None, validate_data=None):
        """
        Train linear model of role arbiter
        Parameters
        ----------
        data_instances: DTable of Instance, input data
        validate_data: only forwarded to ``init_validation_strategy``
        """

        LOGGER.info("Enter hetero linear model arbiter fit")

        self.cipher_operator = self.cipher.paillier_keygen(self.model_param.encrypt_param.key_length)
        self.batch_generator.initialize_batch_generator()
        self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_num)

        self.validation_strategy = self.init_validation_strategy(data_instances, validate_data)

        while self.n_iter_ < self.max_iter:
            # This line was tab-indented in the annotated copy, which made the
            # whole module fail to compile; it is aligned with the loop body now.
            LOGGER.info("iter:" + str(self.n_iter_))
            iter_loss = None
            batch_data_generator = self.batch_generator.generate_batch_data()
            total_gradient = None
            self.optimizer.set_iters(self.n_iter_)
            for batch_index in batch_data_generator:
                # Compute and Transfer gradient info
                gradient = self.gradient_loss_operator.compute_gradient_procedure(self.cipher_operator,
                                                                                  self.optimizer,
                                                                                  self.n_iter_,
                                                                                  batch_index)
                if total_gradient is None:
                    total_gradient = gradient
                else:
                    total_gradient = total_gradient + gradient
                training_info = {"iteration": self.n_iter_, "batch_index": batch_index}
                self.perform_subtasks(**training_info)

                loss_list = self.gradient_loss_operator.compute_loss(self.cipher_operator, self.n_iter_, batch_index)

                # A loss is only accumulated when exactly one value comes back.
                if len(loss_list) == 1:
                    if iter_loss is None:
                        iter_loss = loss_list[0]
                    else:
                        iter_loss += loss_list[0]
                        # LOGGER.info("Get loss from guest:{}".format(de_loss))

            # if converge
            if iter_loss is not None:
                # Average the accumulated loss over the batches of this iteration.
                iter_loss /= self.batch_generator.batch_num
                if self.need_call_back_loss:
                    self.callback_loss(self.n_iter_, iter_loss)
                self.loss_history.append(iter_loss)

            if self.model_param.early_stop == 'weight_diff':
                # LOGGER.debug("total_gradient: {}".format(total_gradient))
                weight_diff = fate_operator.norm(total_gradient)
                # LOGGER.info("iter: {}, weight_diff:{}, is_converged: {}".format(self.n_iter_,
                #                                                                 weight_diff, self.is_converged))
                if weight_diff < self.model_param.tol:
                    self.is_converged = True
            else:
                if iter_loss is None:
                    raise ValueError("Multiple host situation, loss early stop function is not available."
                                     "You should use 'weight_diff' instead")
                self.is_converged = self.converge_func.is_converge(iter_loss)
                LOGGER.info("iter: {},  loss:{}, is_converged: {}".format(self.n_iter_, iter_loss, self.is_converged))

            self.converge_procedure.sync_converge_info(self.is_converged, suffix=(self.n_iter_,))

            if self.validation_strategy:
                LOGGER.debug('Linear Arbiter running validation')
                self.validation_strategy.validate(self, self.n_iter_)
                if self.validation_strategy.need_stop():
                    LOGGER.debug('early stopping triggered')
                    self.best_iteration = self.n_iter_
                    break

            self.n_iter_ += 1
            if self.is_converged:
                break
        summary = {"loss_history": self.loss_history,
                   "is_converged": self.is_converged,
                   "best_iteration": self.best_iteration}
        if self.validation_strategy and self.validation_strategy.has_saved_best_model():
            self.load_model(self.validation_strategy.cur_best_model)
        if self.loss_history is not None and len(self.loss_history) > 0:
            # best_iteration stays -1 unless early stopping fired, which selects
            # the last recorded loss.
            summary["best_iter_loss"] = self.loss_history[self.best_iteration]

        self.set_summary(summary)
        LOGGER.debug("finish running linear model arbiter")


./federatedml/framework/hetero/sync/converge_sync.py
 L19~20:
  from arch.api.utils import log_utils
  LOGGER = log_utils.getLogger()
Arbiter->sync_converge_info()
 L29:
  LOGGER.debug("Arbiter sent to Host: is_converged={}, type={}".format(is_converged, type(is_converged)))
 L31:
  LOGGER.debug("Arbiter sent to Guest: is_converged={}, type={}".format(is_converged, type(is_converged)))

#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#


from federatedml.util import consts

from arch.api.utils import log_utils
# Module-level logger obtained from log_utils; used by the debug statements below.
LOGGER = log_utils.getLogger()
class Arbiter(object):
    """Arbiter-side helper that publishes the convergence flag."""

    # noinspection PyAttributeOutsideInit
    def _register_convergence(self, is_stopped_transfer):
        """Store the transfer variable used to send the convergence flag."""
        self._is_stopped_transfer = is_stopped_transfer

    def sync_converge_info(self, is_converged, suffix=tuple()):
        """Send ``is_converged`` to the HOST role and then to the GUEST role.

        Each debug message is written directly after the send it describes, so
        the log order mirrors the message order and a failure of the second
        send does not hide the first message.

        :param is_converged: convergence flag to publish.
        :param suffix: tag tuple forwarded to the transfer variable.
        """
        self._is_stopped_transfer.remote(obj=is_converged, role=consts.HOST, idx=-1, suffix=suffix)
        LOGGER.debug("Arbiter sent to Host: is_converged={}, type={}".format(is_converged, type(is_converged)))

        self._is_stopped_transfer.remote(obj=is_converged, role=consts.GUEST, idx=-1, suffix=suffix)
        LOGGER.debug("Arbiter sent to Guest: is_converged={}, type={}".format(is_converged, type(is_converged)))


class _Client(object):
    """Receiving side of the convergence flag; shared by Host and Guest."""

    # noinspection PyAttributeOutsideInit
    def _register_convergence(self, is_stopped_transfer):
        """Store the transfer variable the convergence flag is read from."""
        self._is_stopped_transfer = is_stopped_transfer

    def sync_converge_info(self, suffix=tuple()):
        """Return the flag fetched from the transfer variable at ``idx=0``.

        :param suffix: tag tuple forwarded to the transfer variable.
        """
        return self._is_stopped_transfer.get(idx=0, suffix=suffix)


# Host and Guest behave identically here, so both names point at _Client.
Host = Guest = _Client

./federatedml/framework/hetero/sync/loss_sync.py
 L21~22:
  from arch.api.utils import log_utils
  LOGGER = log_utils.getLogger()
Guest->sync_loss_info()
 L41:LOGGER.debug("Guest sent to Arbiter: [[L]]={}, type={}".format(loss, type(loss)))
Guest->get_host_loss_intermediate()
 L45:LOGGER.debug("Guest received from Host: [[(u^H)^2]]={}, type={}".format(loss_intermediate, type(loss_intermediate)))
Guest->get_host_loss_regular()
 L50:
  LOGGER.debug("Guest received from Host: [[(λ/2)(Θ^H)^2]]={}, type={}".format(losses, type(losses)))
Host->remote_loss_intermediate()
 L62:
  LOGGER.debug("Host sent to Guest: [[(u^H)^2]]={}, type={}".format(loss_intermediate, type(loss_intermediate)))
Host->remote_loss_regular()
 L66:
  LOGGER.debug("Host sent to Guest: [[(λ/2)(Θ^H)^2]]={}, type={}".format(loss_regular, type(loss_regular)))

#!/usr/bin/env python
# -*- coding: utf-8 -*-

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from federatedml.util import consts

from arch.api.utils import log_utils
# Module-level logger obtained from log_utils; used by the debug statements below.
LOGGER = log_utils.getLogger()

class Arbiter(object):
    """Arbiter-side receiver of the loss value."""

    def _register_loss_sync(self, loss_transfer):
        """Store the transfer variable the loss is read from."""
        self.loss_transfer = loss_transfer

    def sync_loss_info(self, suffix=tuple()):
        """Return the loss fetched from the transfer variable at ``idx=0``.

        :param suffix: tag tuple forwarded to the transfer variable.
        """
        return self.loss_transfer.get(idx=0, suffix=suffix)


class Guest(object):
    """Guest-side loss exchange: receives host parts, sends the loss to the arbiter."""

    def _register_loss_sync(self, host_loss_regular_transfer, loss_transfer, loss_intermediate_transfer):
        """Store the three transfer variables used by the methods below."""
        self.host_loss_regular_transfer = host_loss_regular_transfer
        self.loss_transfer = loss_transfer
        self.loss_intermediate_transfer = loss_intermediate_transfer

    def sync_loss_info(self, loss, suffix=tuple()):
        """Send ``loss`` to the arbiter (``idx=0``) and log what was sent."""
        self.loss_transfer.remote(loss, role=consts.ARBITER, idx=0, suffix=suffix)
        sent_message = "Guest sent to Arbiter: [[L]]={}, type={}".format(loss, type(loss))
        LOGGER.debug(sent_message)

    def get_host_loss_intermediate(self, suffix=tuple()):
        """Fetch the hosts' intermediate loss terms (``idx=-1``), log and return them."""
        host_intermediate = self.loss_intermediate_transfer.get(idx=-1, suffix=suffix)
        received_message = "Guest received from Host: [[(u^H)^2]]={}, type={}".format(
            host_intermediate, type(host_intermediate))
        LOGGER.debug(received_message)
        return host_intermediate

    def get_host_loss_regular(self, suffix=tuple()):
        """Fetch the hosts' regularisation loss terms (``idx=-1``), log and return them."""
        host_regular = self.host_loss_regular_transfer.get(idx=-1, suffix=suffix)
        received_message = "Guest received from Host: [[(λ/2)(Θ^H)^2]]={}, type={}".format(
            host_regular, type(host_regular))
        LOGGER.debug(received_message)
        return host_regular


class Host(object):
    """Host-side loss exchange: sends the host's loss parts to the guest."""

    def _register_loss_sync(self, host_loss_regular_transfer, loss_transfer, loss_intermediate_transfer):
        """Store the three transfer variables used by the methods below."""
        self.host_loss_regular_transfer = host_loss_regular_transfer
        self.loss_transfer = loss_transfer
        self.loss_intermediate_transfer = loss_intermediate_transfer

    def remote_loss_intermediate(self, loss_intermediate, suffix=tuple()):
        """Send the intermediate loss term to the guest (``idx=0``) and log it."""
        self.loss_intermediate_transfer.remote(obj=loss_intermediate, role=consts.GUEST, idx=0, suffix=suffix)
        sent_message = "Host sent to Guest: [[(u^H)^2]]={}, type={}".format(
            loss_intermediate, type(loss_intermediate))
        LOGGER.debug(sent_message)

    def remote_loss_regular(self, loss_regular, suffix=tuple()):
        """Send the regularisation loss term to the guest (``idx=0``) and log it."""
        self.host_loss_regular_transfer.remote(obj=loss_regular, role=consts.GUEST, idx=0, suffix=suffix)
        sent_message = "Host sent to Guest: [[(λ/2)(Θ^H)^2]]={}, type={}".format(
            loss_regular, type(loss_regular))
        LOGGER.debug(sent_message)

./federatedml/framework/hetero/sync/paillier_keygen_sync.py
 L20~21:
  from arch.api.utils import log_utils
  LOGGER = log_utils.getLogger()
Arbiter->paillier_keygen()
 L32~33:
  priv_key = cipher.get_privacy_key()
  LOGGER.debug("Arbiter generated a Paillier key pair: PK={}, SK={}".format(pub_key, priv_key))
 L35:
  LOGGER.debug("Arbiter sent to Host: PK={}".format(pub_key))
 L37:
  LOGGER.debug("Arbiter sent to Guest: PK={}".format(pub_key))

#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#


from federatedml.secureprotol.encrypt import PaillierEncrypt
from federatedml.util import consts

from arch.api.utils import log_utils
# Module-level logger obtained from log_utils; used by the debug statements below.
LOGGER = log_utils.getLogger()

class Arbiter(object):
    """Arbiter-side helper that creates the Paillier key pair and publishes the public key."""

    # noinspection PyAttributeOutsideInit
    def _register_paillier_keygen(self, pubkey_transfer):
        """Store the transfer variable used to send the public key."""
        self._pubkey_transfer = pubkey_transfer

    def paillier_keygen(self, key_length, suffix=tuple()):
        """Generate a Paillier key pair and send the public key to hosts and guests.

        :param key_length: key length passed to ``PaillierEncrypt.generate_key``.
        :param suffix: tag tuple forwarded to the transfer variable.
        :return: the ``PaillierEncrypt`` instance holding the generated keys.
        """
        cipher = PaillierEncrypt()
        cipher.generate_key(key_length)
        pub_key = cipher.get_public_key()
        priv_key = cipher.get_privacy_key()
        # NOTE(review): this writes the private key to the debug log. That is
        # useful for tracing the protocol in a test deployment, but it must not
        # be enabled where the log is readable by other parties.
        # (This line was tab-indented in the annotated copy, which made the
        # module fail to compile.)
        LOGGER.debug("Arbiter generated a Paillier key pair: PK={}, SK={}".format(pub_key, priv_key))
        self._pubkey_transfer.remote(obj=pub_key, role=consts.HOST, idx=-1, suffix=suffix)
        LOGGER.debug("Arbiter sent to Host: PK={}".format(pub_key))
        self._pubkey_transfer.remote(obj=pub_key, role=consts.GUEST, idx=-1, suffix=suffix)
        LOGGER.debug("Arbiter sent to Guest: PK={}".format(pub_key))
        return cipher


class _Client(object):
    """Receiving side of the Paillier public key; shared by Host and Guest."""

    # noinspection PyAttributeOutsideInit
    def _register_paillier_keygen(self, pubkey_transfer):
        """Store the transfer variable the public key is read from."""
        self._pubkey_transfer = pubkey_transfer

    def gen_paillier_cipher_operator(self, suffix=tuple()):
        """Fetch the public key (``idx=0``) and return a cipher configured with it.

        :param suffix: tag tuple forwarded to the transfer variable.
        :return: a ``PaillierEncrypt`` instance with only the public key set.
        """
        received_key = self._pubkey_transfer.get(idx=0, suffix=suffix)
        operator = PaillierEncrypt()
        operator.set_public_key(received_key)
        return operator


# Host and Guest behave identically here, so both names point at _Client.
Host = Guest = _Client


./federatedml/secureprotol/fate_paillier.py
 L24~25:
  from arch.api.utils import log_utils
  LOGGER = log_utils.getLogger()
PaillierPublicKey
 L71~72:
  def __str__(self):
   return "PaillierPublicKey: g={}, n={}, max_int={}".format(hex(self.g), hex(self.n), hex(self.max_int))
PaillierPublicKey->apply_obfuscator()
 L80~83:
  new_ciphertext = (ciphertext * obfuscator) % self.nsquare
  LOGGER.debug("PaillierPublicKey->apply_obfuscator(): old_cipher={}, new_cipher={}, random={}".format(
   hex(ciphertext), hex(new_ciphertext), hex(r)))
  return new_ciphertext
PaillierPublicKey->raw_encrypt()
 L100~101:
  LOGGER.debug("PaillierPublicKey->raw_encrypt(): plaintext={}, ciphertext={}, PK={}".format(
   hex(plaintext), hex(ciphertext), self))
PaillierPrivateKey
 L151~153:
  def __str__(self):
   return "PaillierPrivateKey: p={}, q={}".format(
    hex(self.p), hex(self.q))
PaillierPrivateKey->raw_decrypt()
 L191~194:
  result = self.crt(mp, mq)
  LOGGER.debug("PaillierPrivateKey->raw_decrypt(): ciphertext={}, plaintext={}, SK={}".format(
   hex(ciphertext), hex(result), self))
  return result
PaillierEncryptedNumber
 L231~233:
  def __str__(self):
   return "PaillierEncryptedNumber: public_key={}, ciphertext={}, exponent={}".format(
    self.public_key, hex(self.__ciphertext), hex(self.exponent))

"""Paillier encryption library for partially homomorphic encryption."""

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

from collections.abc import Mapping
from federatedml.secureprotol.fixedpoint import FixedPointNumber
from federatedml.secureprotol import gmpy_math
import random

from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()

class PaillierKeypair(object):
    """Factory for a matching PaillierPublicKey / PaillierPrivateKey pair."""

    def __init__(self):
        pass

    @staticmethod
    def generate_keypair(n_length=1024):
        """Build a new key pair whose modulus has exactly ``n_length`` bits.

        ``p`` and ``q`` are drawn with ``gmpy_math.getprimeover(n_length // 2)``;
        the draw is repeated until ``p != q`` and ``(p * q).bit_length()``
        equals ``n_length``.

        Returns
        -------
        tuple
            ``(PaillierPublicKey, PaillierPrivateKey)`` built from ``p * q``.
        """
        prime_bits = n_length // 2
        p = q = n = None
        modulus_bits = 0

        while modulus_bits != n_length:
            p = gmpy_math.getprimeover(prime_bits)
            q = gmpy_math.getprimeover(prime_bits)
            while q == p:
                # the private key rejects p == q, so redraw q on a collision
                q = gmpy_math.getprimeover(prime_bits)
            n = p * q
            modulus_bits = n.bit_length()

        public_key = PaillierPublicKey(n)
        return public_key, PaillierPrivateKey(public_key, p, q)


class PaillierPublicKey(object):
    """Contains a public key and associated encryption methods.

    The key is fully determined by the modulus ``n``; the constructor derives
    ``g = n + 1``, ``nsquare = n * n`` and ``max_int = n // 3 - 1`` from it.
    """
    def __init__(self, n):
        self.g = n + 1
        self.n = n
        self.nsquare = n * n
        self.max_int = n // 3 - 1

    def __repr__(self):
        hashcode = hex(hash(self))[2:]
        return "<PaillierPublicKey {}>".format(hashcode[:10])

    def __eq__(self, other):
        return self.n == other.n

    def __hash__(self):
        return hash(self.n)

    def __str__(self):
        return "PaillierPublicKey: g={}, n={}, max_int={}".format(hex(self.g), hex(self.n), hex(self.max_int))

    def apply_obfuscator(self, ciphertext, random_value=None):
        """Return ``ciphertext * r**n mod n**2``.

        ``r`` is ``random_value`` when that is truthy, otherwise a fresh value
        drawn from ``random.SystemRandom`` in ``[1, n)``. The old ciphertext,
        the new ciphertext and ``r`` are written to the debug log.
        """
        r = random_value or random.SystemRandom().randrange(1, self.n)
        obfuscator = gmpy_math.powmod(r, self.n, self.nsquare)

        new_ciphertext = (ciphertext * obfuscator) % self.nsquare
        LOGGER.debug("PaillierPublicKey->apply_obfuscator(): old_cipher={}, new_cipher={}, random={}".format(
            hex(ciphertext), hex(new_ciphertext), hex(r)))
        return new_ciphertext

    def raw_encrypt(self, plaintext, random_value=None):
        """Encrypt an integer ``plaintext`` and return the integer ciphertext.

        The un-blinded ciphertext is ``(n * plaintext + 1) mod n**2``; it is
        logged and then passed through :meth:`apply_obfuscator` together with
        ``random_value``.

        Raises
        ------
        TypeError
            If ``plaintext`` is not an ``int``.
        """
        if not isinstance(plaintext, int):
            raise TypeError("plaintext should be int, but got: %s" %
                            type(plaintext))

        if plaintext >= (self.n - self.max_int) and plaintext < self.n:
            # Very large plaintext, take a sneaky shortcut using inverses
            neg_plaintext = self.n - plaintext  # = abs(plaintext - n)
            neg_ciphertext = (self.n * neg_plaintext + 1) % self.nsquare
            ciphertext = gmpy_math.invert(neg_ciphertext, self.nsquare)
        else:
            ciphertext = (self.n * plaintext + 1) % self.nsquare

        LOGGER.debug("PaillierPublicKey->raw_encrypt(): plaintext={}, ciphertext={}, PK={}".format(
            hex(plaintext), hex(ciphertext), self))
        ciphertext = self.apply_obfuscator(ciphertext, random_value)

        return ciphertext

    def encrypt(self, value, precision=None, random_value=None):
        """Encode and Paillier encrypt a real number value.

        ``value`` is first encoded with ``FixedPointNumber.encode``. When
        ``random_value`` is None the raw encryption uses ``r = 1`` and the
        random blinding is applied afterwards through
        ``PaillierEncryptedNumber.apply_obfuscator``; otherwise
        ``random_value`` is used as ``r`` directly.
        """
        encoding = FixedPointNumber.encode(value, self.n, self.max_int, precision)
        obfuscator = random_value or 1
        ciphertext = self.raw_encrypt(encoding.encoding, random_value=obfuscator)
        encryptednumber = PaillierEncryptedNumber(self, ciphertext, encoding.exponent)
        if random_value is None:
            encryptednumber.apply_obfuscator()

        return encryptednumber


class PaillierPrivateKey(object):
    """Contains a private key and associated decryption method.

    The two primes are stored ordered so that ``self.p < self.q``; the values
    needed for CRT-based decryption (``psquare``, ``qsquare``, ``q_inverse``,
    ``hp``, ``hq``) are precomputed in the constructor.
    """
    def __init__(self, public_key, p, q):
        if not p * q == public_key.n:
            raise ValueError("given public key does not match the given p and q")
        if p == q:
            raise ValueError("p and q have to be different")
        self.public_key = public_key
        if q < p:
            self.p = q
            self.q = p
        else:
            self.p = p
            self.q = q
        self.psquare = self.p * self.p
        self.qsquare = self.q * self.q
        self.q_inverse = gmpy_math.invert(self.q, self.p)
        self.hp = self.h_func(self.p, self.psquare)
        self.hq = self.h_func(self.q, self.qsquare)

    def __eq__(self, other):
        return self.p == other.p and self.q == other.q

    def __hash__(self):
        return hash((self.p, self.q))

    def __repr__(self):
        hashcode = hex(hash(self))[2:]

        return "<PaillierPrivateKey {}>".format(hashcode[:10])

    def __str__(self):
        # NOTE: exposes the secret primes; only meant for protocol tracing.
        return "PaillierPrivateKey: p={}, q={}".format(
            hex(self.p), hex(self.q))

    def h_func(self, x, xsquare):
        """Computes the h-function as defined in Paillier's paper page.
        """
        return gmpy_math.invert(
            self.l_func(gmpy_math.powmod(self.public_key.g, x - 1, xsquare), x), x)

    def l_func(self, x, p):
        """computes the L function as defined in Paillier's paper.
        """

        return (x - 1) // p

    def crt(self, mp, mq):
        """the Chinese Remainder Theorem as needed for decryption.
        return the solution modulo n=pq.
        """
        u = (mp - mq) * self.q_inverse % self.p
        x = (mq + (u * self.q)) % self.public_key.n

        return x

    def raw_decrypt(self, ciphertext):
        """return raw plaintext.

        Decrypts modulo ``p**2`` and ``q**2`` separately and recombines the
        two residues with :meth:`crt`.

        Raises
        ------
        TypeError
            If ``ciphertext`` is not an ``int``.
        """
        if not isinstance(ciphertext, int):
            raise TypeError("ciphertext should be an int, not: %s" %
                            type(ciphertext))

        mp = self.l_func(gmpy_math.powmod(ciphertext,
                                          self.p - 1, self.psquare),
                         self.p) * self.hp % self.p

        mq = self.l_func(gmpy_math.powmod(ciphertext,
                                          self.q - 1, self.qsquare),
                         self.q) * self.hq % self.q

        result = self.crt(mp, mq)
        # The original format string ended in a bare "SK", so the key argument
        # was silently dropped; the placeholder is restored here.
        # NOTE(review): this writes p and q to the debug log -- acceptable for
        # tracing the protocol, must never be enabled in production.
        LOGGER.debug("PaillierPrivateKey->raw_decrypt(): ciphertext={}, plaintext={}, SK={}".format(
            hex(ciphertext), hex(result), self))
        return result

    def decrypt(self, encrypted_number):
        """return the decrypted & decoded plaintext of encrypted_number.

        Raises
        ------
        TypeError
            If ``encrypted_number`` is not a ``PaillierEncryptedNumber``.
        ValueError
            If it was encrypted under a different public key.
        """
        if not isinstance(encrypted_number, PaillierEncryptedNumber):
            # implicit concatenation keeps source indentation out of the message
            raise TypeError("encrypted_number should be an PaillierEncryptedNumber, "
                            "not: %s" % type(encrypted_number))

        if self.public_key != encrypted_number.public_key:
            raise ValueError("encrypted_number was encrypted against a different key!")

        encoded = self.raw_decrypt(encrypted_number.ciphertext(be_secure=False))
        encoded = FixedPointNumber(encoded,
                                   encrypted_number.exponent,
                                   self.public_key.n,
                                   self.public_key.max_int)
        decrypt_value = encoded.decode()

        return decrypt_value


class PaillierEncryptedNumber(object):
    """Represents the Paillier encryption of a float or int.

    Holds the integer ciphertext, the fixed-point ``exponent`` of the encoded
    plaintext and the public key it was encrypted under. Addition and scalar
    multiplication operate on the ciphertext without decrypting it.
    """
    def __init__(self, public_key, ciphertext, exponent=0):
        self.public_key = public_key
        self.__ciphertext = ciphertext
        self.exponent = exponent
        # set to True once a random blinding factor has been applied
        self.__is_obfuscator = False

        if not isinstance(self.__ciphertext, int):
            raise TypeError("ciphertext should be an int, not: %s" % type(self.__ciphertext))

        if not isinstance(self.public_key, PaillierPublicKey):
            raise TypeError("public_key should be a PaillierPublicKey, not: %s" % type(self.public_key))

    def __str__(self):
        return "PaillierEncryptedNumber: public_key={}, ciphertext={}, exponent={}".format(
            self.public_key, hex(self.__ciphertext), hex(self.exponent))

    def ciphertext(self, be_secure=True):
        """return the ciphertext of the PaillierEncryptedNumber.

        With ``be_secure=True`` the ciphertext is blinded first if that has
        not happened yet.
        """
        if be_secure and not self.__is_obfuscator:
            self.apply_obfuscator()

        return self.__ciphertext

    def apply_obfuscator(self):
        """ciphertext by multiplying by r ** n with random r
        """
        self.__ciphertext = self.public_key.apply_obfuscator(self.__ciphertext)
        self.__is_obfuscator = True

    def __add__(self, other):
        if isinstance(other, PaillierEncryptedNumber):
            return self.__add_encryptednumber(other)
        else:
            return self.__add_scalar(other)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        return self + (other * -1)

    def __rsub__(self, other):
        return other + (self * -1)

    def __rmul__(self, scalar):
        return self.__mul__(scalar)

    def __truediv__(self, scalar):
        return self.__mul__(1 / scalar)

    def __mul__(self, scalar):
        """return Multiply by an scalar(such as int, float)
        """

        encode = FixedPointNumber.encode(scalar, self.public_key.n, self.public_key.max_int)
        plaintext = encode.encoding

        if plaintext < 0 or plaintext >= self.public_key.n:
            raise ValueError("Scalar out of bounds: %i" % plaintext)

        if plaintext >= self.public_key.n - self.public_key.max_int:
            # Very large plaintext, play a sneaky trick using inverses
            neg_c = gmpy_math.invert(self.ciphertext(False), self.public_key.nsquare)
            neg_scalar = self.public_key.n - plaintext
            ciphertext = gmpy_math.powmod(neg_c, neg_scalar, self.public_key.nsquare)
        else:
            ciphertext = gmpy_math.powmod(self.ciphertext(False), plaintext, self.public_key.nsquare)

        exponent = self.exponent + encode.exponent

        return PaillierEncryptedNumber(self.public_key, ciphertext, exponent)

    def increase_exponent_to(self, new_exponent):
        """return PaillierEncryptedNumber:
           new PaillierEncryptedNumber with same value but having great exponent.
        """
        if new_exponent < self.exponent:
            raise ValueError("New exponent %i should be great than old exponent %i" % (new_exponent, self.exponent))

        factor = pow(FixedPointNumber.BASE, new_exponent - self.exponent)
        new_encryptednumber = self.__mul__(factor)
        new_encryptednumber.exponent = new_exponent

        return new_encryptednumber

    def __align_exponent(self, x, y):
        """return x,y with same exponent
        """
        if x.exponent < y.exponent:
            x = x.increase_exponent_to(y.exponent)
        elif x.exponent > y.exponent:
            y = y.increase_exponent_to(x.exponent)

        return x, y

    def __add_scalar(self, scalar):
        """return PaillierEncryptedNumber: z = E(x) + y
        """
        encoded = FixedPointNumber.encode(scalar,
                                          self.public_key.n,
                                          self.public_key.max_int,
                                          max_exponent=self.exponent)

        return self.__add_fixpointnumber(encoded)

    def __add_fixpointnumber(self, encoded):
        """return PaillierEncryptedNumber: z = E(x) + FixedPointNumber(y)
        """
        if self.public_key.n != encoded.n:
            raise ValueError("Attempted to add numbers encoded against different public keys!")

        # their exponents must match, and align.
        x, y = self.__align_exponent(self, encoded)

        # r = 1: the scalar is encrypted without random blinding
        encrypted_scalar = x.public_key.raw_encrypt(y.encoding, 1)
        encryptednumber = self.__raw_add(x.ciphertext(False), encrypted_scalar, x.exponent)

        return encryptednumber

    def __add_encryptednumber(self, other):
        """return PaillierEncryptedNumber: z = E(x) + E(y)
        """
        if self.public_key != other.public_key:
            raise ValueError("add two numbers have different public key!")

        # their exponents must match, and align.
        x, y = self.__align_exponent(self, other)

        encryptednumber = self.__raw_add(x.ciphertext(False), y.ciphertext(False), x.exponent)

        return encryptednumber

    def __raw_add(self, e_x, e_y, exponent):
        """return the integer E(x + y) given ints E(x) and E(y).
        """
        ciphertext = e_x * e_y % self.public_key.nsquare

        return PaillierEncryptedNumber(self.public_key, ciphertext, exponent)


./federatedml/secureprotol/fixedpoint.py
 L22~23:
  from arch.api.utils import log_utils
  LOGGER = log_utils.getLogger()
FixedPointNumber->encode()
 L86~88:
  result = cls(int_fixpoint % n, exponent, n, max_int)
  LOGGER.debug("FixedPointNumber->encode(): value={}, FP={}".format(scalar, result))
  return result
FixedPointNumber->decode()
 L105~107:
  result = mantissa * pow(self.BASE, -self.exponent)
  LOGGER.debug("FixedPointNumber->decode(): FP={}, value={}".format(self, result))
  return result
FixedPointNumber
 L271~273:
  def __str__(self):
   return "FixedPointNumber: encoding={}, exponent={}".format(
    hex(self.encoding), hex(self.exponent))

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

import math
import sys

import numpy as np

from arch.api.utils import log_utils
LOGGER = log_utils.getLogger()

class FixedPointNumber(object):
    """Represents a float or int fixed-point encoding.

    A value ``v`` is stored as ``encoding = round(v * BASE**exponent) mod n``.
    Encodings up to ``max_int`` are positive; encodings from ``n - max_int``
    upwards represent negative numbers. When no modulus is supplied the class
    constant ``Q`` is used, with ``max_int = Q // 3 - 1``.
    """
    BASE = 16
    LOG2_BASE = math.log(BASE, 2)
    FLOAT_MANTISSA_BITS = sys.float_info.mant_dig

    Q = 293973345475167247070445277780365744413

    def __init__(self, encoding, exponent, n=None, max_int=None):
        self.n = n
        self.max_int = max_int

        if self.n is None:
            self.n = self.Q
            self.max_int = self.Q // 3 - 1

        self.encoding = encoding
        self.exponent = exponent

    @classmethod
    def encode(cls, scalar, n=None, max_int=None, precision=None, max_exponent=None):
        """return an encoding of an int or float.

        Raises
        ------
        TypeError
            If ``precision`` is None and ``scalar`` is not a supported int or
            float type.
        ValueError
            If the scaled integer exceeds ``max_int`` in absolute value.
        """
        # Calculate the maximum exponent for desired precision
        exponent = None

        #  Too low value preprocess;
        #  avoid "OverflowError: int too large to convert to float"

        if np.abs(scalar) < 1e-200:
            scalar = 0

        if n is None:
            n = cls.Q
            max_int = cls.Q // 3 - 1

        if precision is None:
            if isinstance(scalar, int) or isinstance(scalar, np.int16) or \
                    isinstance(scalar, np.int32) or isinstance(scalar, np.int64):
                exponent = 0
            elif isinstance(scalar, float) or isinstance(scalar, np.float16) \
                    or isinstance(scalar, np.float32) or isinstance(scalar, np.float64):
                flt_exponent = math.frexp(scalar)[1]
                lsb_exponent = cls.FLOAT_MANTISSA_BITS - flt_exponent
                exponent = math.floor(lsb_exponent / cls.LOG2_BASE)
            else:
                raise TypeError("Don't know the precision of type %s."
                                % type(scalar))
        else:
            exponent = math.floor(math.log(precision, cls.BASE))

        if max_exponent is not None:
            exponent = max(max_exponent, exponent)

        int_fixpoint = int(round(scalar * pow(cls.BASE, exponent)))

        if abs(int_fixpoint) > max_int:
            raise ValueError('Integer needs to be within +/- %d but got %d'
                             % (max_int, int_fixpoint))

        result = cls(int_fixpoint % n, exponent, n, max_int)
        LOGGER.debug("FixedPointNumber->encode(): value={}, FP={}".format(scalar, result))
        return result

    def decode(self):
        """return decode plaintext.

        Raises
        ------
        ValueError
            If the encoding is not reduced modulo ``n``.
        OverflowError
            If the encoding lies between the positive and negative ranges.
        """
        if self.encoding >= self.n:
            # Should be mod n
            raise ValueError('Attempted to decode corrupted number')
        elif self.encoding <= self.max_int:
            # Positive
            mantissa = self.encoding
        elif self.encoding >= self.n - self.max_int:
            # Negative
            mantissa = self.encoding - self.n
        else:
            raise OverflowError('Overflow detected in decode number')

        result = mantissa * pow(self.BASE, -self.exponent)
        LOGGER.debug("FixedPointNumber->decode(): FP={}, value={}".format(self, result))
        return result

    def increase_exponent_to(self, new_exponent):
        """return FixedPointNumber: new encoding with same value but having great exponent.

        Raises
        ------
        ValueError
            If ``new_exponent`` is smaller than the current exponent.
        """
        if new_exponent < self.exponent:
            # trailing space added: the two literals used to join as "thanold"
            raise ValueError('New exponent %i should be greater than '
                             'old exponent %i' % (new_exponent, self.exponent))

        factor = pow(self.BASE, new_exponent - self.exponent)
        new_encoding = self.encoding * factor % self.n

        return FixedPointNumber(new_encoding, new_exponent, self.n, self.max_int)

    def __align_exponent(self, x, y):
        """return x,y with same exponent
        """
        if x.exponent < y.exponent:
            x = x.increase_exponent_to(y.exponent)
        elif x.exponent > y.exponent:
            y = y.increase_exponent_to(x.exponent)

        return x, y

    def __truncate(self, a):
        # decode and re-encode so the exponent is recomputed from the value
        scalar = a.decode()
        return FixedPointNumber.encode(scalar)

    def __add__(self, other):
        if isinstance(other, FixedPointNumber):
            return self.__add_fixpointnumber(other)
        else:
            return self.__add_scalar(other)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        if isinstance(other, FixedPointNumber):
            return self.__sub_fixpointnumber(other)
        else:
            return self.__sub_scalar(other)

    def __rsub__(self, other):
        # other - self == -(self - other)
        x = self.__sub__(other)
        x = -1 * x.decode()
        return self.encode(x)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __mul__(self, other):
        if isinstance(other, FixedPointNumber):
            return self.__mul_fixpointnumber(other)
        else:
            return self.__mul_scalar(other)

    def __truediv__(self, other):
        if isinstance(other, FixedPointNumber):
            scalar = other.decode()
        else:
            scalar = other

        return self.__mul__(1 / scalar)

    def __rtruediv__(self, other):
        # other / self == 1 / (self / other)
        res = 1.0 / self.__truediv__(other).decode()
        return FixedPointNumber.encode(res)

    def __lt__(self, other):
        x = self.decode()
        if isinstance(other, FixedPointNumber):
            y = other.decode()
        else:
            y = other
        if x < y:
            return True
        else:
            return False

    def __gt__(self, other):
        x = self.decode()
        if isinstance(other, FixedPointNumber):
            y = other.decode()
        else:
            y = other
        if x > y:
            return True
        else:
            return False

    def __le__(self, other):
        x = self.decode()
        if isinstance(other, FixedPointNumber):
            y = other.decode()
        else:
            y = other
        if x <= y:
            return True
        else:
            return False

    def __ge__(self, other):
        x = self.decode()
        if isinstance(other, FixedPointNumber):
            y = other.decode()
        else:
            y = other

        if x >= y:
            return True
        else:
            return False

    def __eq__(self, other):
        x = self.decode()
        if isinstance(other, FixedPointNumber):
            y = other.decode()
        else:
            y = other
        if x == y:
            return True
        else:
            return False

    def __ne__(self, other):
        x = self.decode()
        if isinstance(other, FixedPointNumber):
            y = other.decode()
        else:
            y = other
        if x != y:
            return True
        else:
            return False

    def __add_fixpointnumber(self, other):
        # NOTE(review): reduces modulo the class constant Q and returns a
        # number with the default modulus, even when self.n differs -- confirm
        # this is intended for operands built with a custom n.
        x, y = self.__align_exponent(self, other)
        encoding = (x.encoding + y.encoding) % self.Q
        return FixedPointNumber(encoding, x.exponent)

    def __add_scalar(self, scalar):
        encoded = self.encode(scalar)
        return self.__add_fixpointnumber(encoded)

    def __sub_fixpointnumber(self, other):
        scalar = -1 * other.decode()
        return self.__add_scalar(scalar)

    def __sub_scalar(self, scalar):
        scalar = -1 * scalar
        return self.__add_scalar(scalar)

    def __mul_fixpointnumber(self, other):
        encoding = (self.encoding * other.encoding) % self.Q
        exponet = self.exponent + other.exponent
        mul_fixedpoint = FixedPointNumber(encoding, exponet)
        truncate_mul_fixedpoint = self.__truncate(mul_fixedpoint)
        return truncate_mul_fixedpoint

    def __mul_scalar(self, scalar):
        encoded = self.encode(scalar)
        return self.__mul_fixpointnumber(encoded)

    def __str__(self):
        return "FixedPointNumber: encoding={}, exponent={}".format(
            hex(self.encoding), hex(self.exponent))


./federatedml/optim/gradient/hetero_linear_model_gradient.py
Guest->compute_gradient_procedure()
 L178:
  LOGGER.debug("Guest computed: [[d]]x^G={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
 L181:
  LOGGER.debug("Guest computed: [[d]]x^G+[[λΘ^G]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
Guest->get_host_forward()
 L188:
  LOGGER.debug("Guest received from Host: [[u^H]]={}, type={}".format(host_forward, type(host_forward)))
Guest->remote_fore_gradient()
 L193:
  LOGGER.debug("Guest sent to Host: [[d]]={}, type={}".format(fore_gradient, type(fore_gradient)))
Guest->update_gradient()
 L197:
  LOGGER.debug("Guest sent to Arbiter: [[g^G]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
 L199:
  LOGGER.debug("Guest received from Arbiter: g^G={}, type={}".format(optimized_gradient, type(optimized_gradient)))
Host->compute_gradient_procedure()
 L234:
  LOGGER.debug("Host computed: u^H={}, type={}".format(list(self.forwards.collect()), type(self.forwards)))
 L236:
  LOGGER.debug("Host encrypted: [[u^H]]={}, type={}".format(list(encrypted_forward.collect()), type(encrypted_forward)))
 L244:
  LOGGER.debug("Host computed: [[d]]x^H={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
 L247:
  LOGGER.debug("Host computed: [[d]]x^H+[[λΘ^H]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
Host->remote_host_forward()
 L278:
  LOGGER.debug("Host sent to Guest: [[u^H]]={}, type={}".format(host_forward, type(host_forward)))
Host->get_fore_gradient()
 L282:
  LOGGER.debug("Host received from Guest: fore_gradient={}, type={}".format(host_forward, type(host_forward)))
Host->update_gradient()
 L287:
  LOGGER.debug("Host sent to Arbiter: [[g^H]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
 L289:
  LOGGER.debug("Host received from Arbiter: g^H={}, type={}".format(optimized_gradient, type(optimized_gradient)))

#!/usr/bin/env python
# -*- coding: utf-8 -*-

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import functools

import numpy as np
import scipy.sparse as sp

from federatedml.feature.sparse_vector import SparseVector
from federatedml.statistic import data_overview
from federatedml.util import LOGGER
from federatedml.util import consts
from federatedml.util import fate_operator


def __compute_partition_gradient(data, fit_intercept=True, is_sparse=False):
    """
    Sum d * x over the rows of one partition (not yet divided by the row count).

    Parameters
    ----------
    data: iterable of (key, (features, d)) pairs; d is the fore_gradient of the row
    fit_intercept: bool, when True the sum of all d is appended as the bias gradient
    is_sparse: bool, when True every features entry must be a SparseVector

    Returns
    ----------
    numpy.ndarray
        the partition gradient, or the int 0 when there is nothing to sum
        (no rows, or a sparse feature shape of 0)
    """
    if is_sparse:
        # COO-style triplets, assembled into a CSR matrix after the scan
        values, rows, cols = [], [], []
        d_list = []
        n_features = None

        for key, (sparse_vec, d) in data:
            row_no = len(d_list)
            d_list.append(d)
            assert isinstance(sparse_vec, SparseVector)
            if n_features is None:
                n_features = sparse_vec.get_shape()
            for col_no, val in sparse_vec.get_all_data():
                cols.append(col_no)
                rows.append(row_no)
                values.append(val)

        if n_features is None or n_features == 0:
            return 0

        matrix = sp.csr_matrix((values, (rows, cols)), shape=(len(d_list), n_features))
        d_vec = np.array(d_list)

        grad = fate_operator.dot(matrix.transpose(), d_vec).tolist()
        if fit_intercept:
            grad.append(np.sum(d_vec))
        return np.array(grad)

    # dense case: stack the feature rows and the fore_gradients side by side
    pairs = [value for key, value in data]
    feature = np.array([pair[0] for pair in pairs])
    fore_gradient = np.array([pair[1] for pair in pairs])
    if feature.shape[0] <= 0:
        return 0

    grad = fate_operator.dot(feature.transpose(), fore_gradient).tolist()
    if fit_intercept:
        grad.append(np.sum(fore_gradient))
    return np.array(grad)


def compute_gradient(data_instances, fore_gradient, fit_intercept):
    """
    Average hetero-regression gradient over all instances.

    The features of each instance are joined with its fore_gradient, the
    per-partition sums are computed and added together, and the total is
    divided by ``data_instances.count()``.

    Parameters
    ----------
    data_instances: DTable, input data
    fore_gradient: DTable, fore_gradient
    fit_intercept: bool, if model has intercept or not

    Returns
    ----------
    the reduced partition gradient divided by the instance count
    """
    joined = data_instances.join(fore_gradient,
                                 lambda inst, grad: (inst.features, grad))
    partition_fn = functools.partial(__compute_partition_gradient,
                                     fit_intercept=fit_intercept,
                                     is_sparse=data_overview.is_sparse_data(data_instances))
    total = joined.applyPartitions(partition_fn).reduce(lambda lhs, rhs: lhs + rhs)

    return total / data_instances.count()


class HeteroGradientBase(object):
    """Common base of the hetero gradient roles defined in this module."""

    def compute_gradient_procedure(self, *args):
        """Always raises NotImplementedError in the base class."""
        raise NotImplementedError("Should not call here")

    def set_total_batch_nums(self, total_batch_nums):
        """
        Use for sqn gradient; does nothing in the base class.
        """
        return None


class Guest(HeteroGradientBase):
    """Guest-side gradient procedure for hetero linear models.

    The four transfer objects used to exchange data with the host and the
    arbiter must be installed with ``_register_gradient_sync`` before
    ``compute_gradient_procedure`` is called.
    """
    def __init__(self):
        self.host_forwards = None
        self.forwards = None
        self.aggregated_forwards = None

    def _register_gradient_sync(self, host_forward_transfer, fore_gradient_transfer,
                                guest_gradient_transfer, guest_optim_gradient_transfer):
        """Store the transfer objects used by the sync helpers below."""
        self.host_forward_transfer = host_forward_transfer
        self.fore_gradient_transfer = fore_gradient_transfer
        self.unilateral_gradient_transfer = guest_gradient_transfer
        self.unilateral_optim_gradient_transfer = guest_optim_gradient_transfer

    def compute_and_aggregate_forwards(self, data_instances, model_weights,
                                       encrypted_calculator, batch_index, current_suffix, offset=None):
        """Placeholder that always raises here; its result is used as the fore_gradient."""
        raise NotImplementedError("Function should not be called here")

    def compute_gradient_procedure(self, data_instances, encrypted_calculator, model_weights, optimizer,
                                   n_iter_, batch_index, offset=None):
        """
          Linear model gradient procedure
          Step 1: get host forwards which differ from different algorithm
                  For Logistic Regression and Linear Regression: forwards = wx
                  For Poisson Regression, forwards = exp(wx)

          Step 2: Compute self forwards and aggregate host forwards and get d = fore_gradient

          Step 3: Compute unilateral gradient = ∑d*x,

          Step 4: Send unilateral gradients to arbiter and received the optimized and decrypted gradient.

          Returns a tuple (optimized_gradient, fore_gradient, self.host_forwards).
          """
        current_suffix = (n_iter_, batch_index)
        # NOTE(review): self.host_forwards is only initialised to None in this
        # class; presumably compute_and_aggregate_forwards fills it in -- verify.

        fore_gradient = self.compute_and_aggregate_forwards(data_instances, model_weights, encrypted_calculator,
                                                            batch_index, current_suffix, offset)

        self.remote_fore_gradient(fore_gradient, suffix=current_suffix)

        unilateral_gradient = compute_gradient(data_instances,
                                               fore_gradient,
                                               model_weights.fit_intercept)
        # Logged before the regular term is added so the value matches the label.
        LOGGER.debug("Guest computed: [[d]]x^G={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))

        if optimizer is not None:
            unilateral_gradient = optimizer.add_regular_to_grad(unilateral_gradient, model_weights)
            # Logged right after regularisation, before the arbiter round-trip.
            LOGGER.debug("Guest computed: [[d]]x^G+[[λΘ^G]]={}, type={}".format(
                unilateral_gradient, type(unilateral_gradient)))

        optimized_gradient = self.update_gradient(unilateral_gradient, suffix=current_suffix)

        return optimized_gradient, fore_gradient, self.host_forwards

    def get_host_forward(self, suffix=tuple()):
        """Fetch the host forwards through ``host_forward_transfer`` and log them."""
        host_forward = self.host_forward_transfer.get(idx=-1, suffix=suffix)
        LOGGER.debug("Guest received from Host: [[u^H]]={}, type={}".format(host_forward, type(host_forward)))
        return host_forward

    def remote_fore_gradient(self, fore_gradient, suffix=tuple()):
        """Send ``fore_gradient`` to the host role and log it."""
        self.fore_gradient_transfer.remote(obj=fore_gradient, role=consts.HOST, idx=-1, suffix=suffix)
        LOGGER.debug("Guest sent to Host: [[d]]={}, type={}".format(fore_gradient, type(fore_gradient)))

    def update_gradient(self, unilateral_gradient, suffix=tuple()):
        """Send the gradient to the arbiter and return the gradient it sends back."""
        self.unilateral_gradient_transfer.remote(unilateral_gradient, role=consts.ARBITER, idx=0, suffix=suffix)
        LOGGER.debug("Guest sent to Arbiter: [[g^G]]={}, type={}".format(unilateral_gradient, type(unilateral_gradient)))
        optimized_gradient = self.unilateral_optim_gradient_transfer.get(idx=0, suffix=suffix)
        LOGGER.debug("Guest received from Arbiter: g^G={}, type={}".format(optimized_gradient, type(optimized_gradient)))
        return optimized_gradient


class Host(HeteroGradientBase):
    """
    Host-role gradient helper for hetero linear models.

    ``compute_forwards`` and ``compute_unilateral_gradient`` only raise here;
    ``compute_gradient_procedure`` relies on ``compute_forwards`` being
    provided elsewhere (for example by a subclass).
    """

    def __init__(self):
        self.fore_gradient = None
        self.forwards = None

    def _register_gradient_sync(self, host_forward_transfer, fore_gradient_transfer,
                                host_gradient_transfer, host_optim_gradient_transfer):
        """Remember the four transfer variables used during the gradient exchange."""
        self.unilateral_optim_gradient_transfer = host_optim_gradient_transfer
        self.unilateral_gradient_transfer = host_gradient_transfer
        self.fore_gradient_transfer = fore_gradient_transfer
        self.host_forward_transfer = host_forward_transfer

    def compute_forwards(self, data_instances, model_weights):
        """Placeholder; always raises NotImplementedError in this class."""
        raise NotImplementedError("Function should not be called here")

    def compute_unilateral_gradient(self, data_instances, fore_gradient, model_weights, optimizer):
        """Placeholder; always raises NotImplementedError in this class."""
        raise NotImplementedError("Function should not be called here")

    def compute_gradient_procedure(self, data_instances, encrypted_calculator, model_weights,
                                   optimizer,
                                   n_iter_, batch_index):
        """
        Run one host-side gradient round for a linear model.

        Steps, in order:
          1. compute the local forwards and encrypt them with the calculator
             selected by ``batch_index``
          2. send the encrypted forwards to the guest and read back the
             fore gradient
          3. build the local gradient with ``compute_gradient`` and, when an
             optimizer is supplied, add its regular term
          4. exchange the gradient with the arbiter via ``update_gradient``

        Returns
        ----------
        tuple of (gradient received from the arbiter, fore gradient)
        """
        round_tag = (n_iter_, batch_index)

        self.forwards = self.compute_forwards(data_instances, model_weights)
        LOGGER.debug("Host computed: u^H={}, type={}".format(list(self.forwards.collect()), type(self.forwards)))
        cipher_forwards = encrypted_calculator[batch_index].encrypt(self.forwards)
        LOGGER.debug("Host encrypted: [[u^H]]={}, type={}".format(list(cipher_forwards.collect()), type(cipher_forwards)))

        self.remote_host_forward(cipher_forwards, suffix=round_tag)
        fore_gradient = self.get_fore_gradient(suffix=round_tag)

        local_grad = compute_gradient(data_instances, fore_gradient, model_weights.fit_intercept)
        LOGGER.debug("Host computed: [[d]]x^H={}, type={}".format(local_grad, type(local_grad)))
        if optimizer is not None:
            local_grad = optimizer.add_regular_to_grad(local_grad, model_weights)
        LOGGER.debug("Host computed: [[d]]x^H+[[λΘ^H]]={}, type={}".format(local_grad, type(local_grad)))

        arbiter_grad = self.update_gradient(local_grad, suffix=round_tag)
        return arbiter_grad, fore_gradient

    def compute_sqn_forwards(self, data_instances, delta_s, cipher_operator):
        """
        Map every instance to the encrypted value of
        ``vec_dot(features, delta_s.coef_) + delta_s.intercept_``.

        Background from the original notes:
        g = (1/N)*∑(0.25 * wx - 0.5 * y) * x
        y = ∇2^F(w_t)s_t = g' * s = (1/N)*∑(0.25 * x * s) * x
        define forward_hess = ∑(0.25 * x * s)
        """

        def _encrypted_forward(instance):
            plain = fate_operator.vec_dot(instance.features, delta_s.coef_) + delta_s.intercept_
            return cipher_operator.encrypt(plain)

        return data_instances.mapValues(_encrypted_forward)

    def compute_forward_hess(self, data_instances, delta_s, forward_hess):
        """
        Feed ``forward_hess`` through ``compute_gradient`` and return the
        result wrapped in a numpy array.

        Background from the original notes:
        g = (1/N)*∑(0.25 * wx - 0.5 * y) * x
        y = ∇2^F(w_t)s_t = g' * s = (1/N)*∑(0.25 * x * s) * x
        define forward_hess = (0.25 * x * s)
        """
        return np.array(compute_gradient(data_instances, forward_hess, delta_s.fit_intercept))

    def remote_host_forward(self, host_forward, suffix=tuple()):
        """Send ``host_forward`` to the GUEST role (``idx=0``) and log it."""
        self.host_forward_transfer.remote(role=consts.GUEST, obj=host_forward, suffix=suffix, idx=0)
        sent_type = type(host_forward)
        LOGGER.debug("Host sent to Guest: [[u^H]]={}, type={}".format(host_forward, sent_type))

    def get_fore_gradient(self, suffix=tuple()):
        """Read the fore gradient from ``fore_gradient_transfer`` (``idx=0``), log and return it."""
        received = self.fore_gradient_transfer.get(suffix=suffix, idx=0)
        LOGGER.debug("Host received from Guest: fore_gradient={}, type={}".format(received, type(received)))
        return received

    def update_gradient(self, unilateral_gradient, suffix=tuple()):
        """
        Push the local gradient to the ARBITER role, then return the reply
        read from ``unilateral_optim_gradient_transfer``.
        """
        self.unilateral_gradient_transfer.remote(unilateral_gradient, suffix=suffix, idx=0, role=consts.ARBITER)
        LOGGER.debug("Host sent to Arbiter: [[g^H]]={}, type={}".format(unilateral_gradient,
                                                                       type(unilateral_gradient)))
        reply = self.unilateral_optim_gradient_transfer.get(suffix=suffix, idx=0)
        LOGGER.debug("Host received from Arbiter: g^H={}, type={}".format(reply, type(reply)))
        return reply


class Arbiter(HeteroGradientBase):
    """
    Arbiter-role gradient helper for hetero linear models.

    Collects the gradients sent by the hosts and the guest, decrypts them,
    lets the optimizer turn them into an update, and sends each party its
    own slice back.
    """

    def __init__(self):
        # Set to True once a round delivers gradients from more than one host.
        self.has_multiple_hosts = False

    def _register_gradient_sync(self, guest_gradient_transfer, host_gradient_transfer,
                                guest_optim_gradient_transfer, host_optim_gradient_transfer):
        """Store the transfer variables used to receive and return gradients."""
        self.guest_gradient_transfer = guest_gradient_transfer
        self.host_gradient_transfer = host_gradient_transfer
        self.guest_optim_gradient_transfer = guest_optim_gradient_transfer
        self.host_optim_gradient_transfer = host_optim_gradient_transfer

    def compute_gradient_procedure(self, cipher_operator, optimizer, n_iter_, batch_index):
        """
        Compute gradients.
        Received local_gradients from guest and hosts. Merge and optimize, then separate and remote back.
        Parameters
        ----------
        cipher_operator: Use for encryption

        optimizer: optimizer that get delta gradient of this iter

        n_iter_: int, current iter nums

        batch_index: int, use to obtain current encrypted_calculator

        Returns
        ----------
        the value returned by ``optimizer.apply_gradients`` for the merged,
        decrypted gradient (hosts first, guest last)
        """
        current_suffix = (n_iter_, batch_index)

        host_gradients, guest_gradient = self.get_local_gradient(current_suffix)

        if len(host_gradients) > 1:
            self.has_multiple_hosts = True

        host_gradients = [np.array(h) for h in host_gradients]
        guest_gradient = np.array(guest_gradient)

        # Remember every party's length so the merged update can be cut back
        # into per-party pieces; the guest is always the last entry.
        size_list = [h_g.shape[0] for h_g in host_gradients]
        size_list.append(guest_gradient.shape[0])

        # np.hstack requires a real sequence: handing it a generator has been
        # deprecated since NumPy 1.16 and is rejected by recent releases.
        # Stacking everything in one call keeps the host-then-guest order.
        gradient = np.hstack(host_gradients + [guest_gradient])

        grad = np.array(cipher_operator.decrypt_list(gradient))

        delta_grad = optimizer.apply_gradients(grad)

        separate_optim_gradient = self.separate(delta_grad, size_list)
        host_optim_gradients = separate_optim_gradient[: -1]
        guest_optim_gradient = separate_optim_gradient[-1]

        self.remote_local_gradient(host_optim_gradients, guest_optim_gradient, current_suffix)
        return delta_grad

    @staticmethod
    def separate(value, size_list):
        """
        Separate value in order to several set according size_list
        Parameters
        ----------
        value: list or ndarray, input data
        size_list: list, each set size

        Returns
        ----------
        list
            set after separate
        """
        separate_res = []
        cur = 0
        for size in size_list:
            separate_res.append(value[cur:cur + size])
            cur += size
        return separate_res

    def get_local_gradient(self, suffix=tuple()):
        """
        Receive the gradients of all hosts (``idx=-1``) and of the guest (``idx=0``).

        Returns
        ----------
        tuple of (host_gradients, guest_gradient)
        """
        host_gradients = self.host_gradient_transfer.get(idx=-1, suffix=suffix)
        LOGGER.info("Get host_gradient from Host")

        guest_gradient = self.guest_gradient_transfer.get(idx=0, suffix=suffix)
        LOGGER.info("Get guest_gradient from Guest")
        return host_gradients, guest_gradient

    def remote_local_gradient(self, host_optim_gradients, guest_optim_gradient, suffix=tuple()):
        """
        Send every host its slice (addressed by its position in
        ``host_optim_gradients``) and then send the guest its slice.
        """
        for idx, host_optim_gradient in enumerate(host_optim_gradients):
            self.host_optim_gradient_transfer.remote(host_optim_gradient,
                                                     role=consts.HOST,
                                                     idx=idx,
                                                     suffix=suffix)

        self.guest_optim_gradient_transfer.remote(guest_optim_gradient,
                                                  role=consts.GUEST,
                                                  idx=0,
                                                  suffix=suffix)


./federatedml/optim/gradient/hetero_linr_gradient_and_loss.py
Guest->compute_and_aggregate_forwards()
 L65:
  LOGGER.debug("Guest computed forwards: u^G={}, type={}".format(self.forwards, type(self.forwards)))
 L67:
  LOGGER.debug("Guest encrypted forwards: [[u^G]]={}, type={}".format(self.aggregated_forwards, type(self.aggregated_forwards)))
 L71:
  LOGGER.debug("Guest computed: [[u^G+u^H]]={}, type={}".format(self.aggregated_forwards, type(self.aggregated_forwards)))
 L73:
  LOGGER.debug("Guest computed: [[d]]=[[u^G+u^H-y]]={}, type={}".format(fore_gradient, type(fore_gradient)))
Guest->compute_loss()
 L98:
  LOGGER.debug("Guest computed: (u^G-y)={}, type={}".format(wxy, type(wxy)))
 L100:
  LOGGER.debug("Guest computed: (u^G-y)^2={}, type={}".format(wxy_square, type(wxy_square)))
 L103:
  LOGGER.debug("Guest computed: [[u^H(u^G-y)]]={}, type={}".format(loss_gh, type(loss_gh)))
 L105:
  LOGGER.debug("Guest computed: [[(u^H+u^G-y)^2]]={}, type={}".format(loss, type(loss)))
 L108:
  LOGGER.debug("Guest computed: [[L]]={}, type={}".format(loss, type(loss)))
Host->compute_loss()
 L155:
  LOGGER.debug("Host computed: (u^H)^2={}, type={}".format(self_wx_square, type(self_wx_square)))
 L157:
  LOGGER.debug("Host encrypted: [[(u^H)^2]]={}, type={}".format(en_wx_square, type(en_wx_square)))
 L162:
  LOGGER.debug("Host computed: (λ/2)(Θ^H)^2={}, type={}".format(loss_regular, type(loss_regular)))
 L164:
  LOGGER.debug("Host encrypted: [[(λ/2)(Θ^H)^2]]={}, type={}".format(en_loss_regular, type(en_loss_regular)))

#!/usr/bin/env python
# -*- coding: utf-8 -*-

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import numpy as np

from federatedml.framework.hetero.sync import loss_sync
from federatedml.optim.gradient import hetero_linear_model_gradient
from federatedml.util import LOGGER
from federatedml.util.fate_operator import reduce_add, vec_dot


class Guest(hetero_linear_model_gradient.Guest, loss_sync.Guest):
    """
    Guest-role gradient and loss operator for hetero linear regression.
    """

    def register_gradient_procedure(self, transfer_variables):
        """Register the transfer variables used for gradient and loss synchronisation."""
        self._register_gradient_sync(transfer_variables.host_forward,
                                     transfer_variables.fore_gradient,
                                     transfer_variables.guest_gradient,
                                     transfer_variables.guest_optim_gradient)

        self._register_loss_sync(transfer_variables.host_loss_regular,
                                 transfer_variables.loss,
                                 transfer_variables.loss_intermediate)

    def compute_and_aggregate_forwards(self, data_instances, model_weights,
                                       encrypted_calculator, batch_index, current_suffix, offset=None):
        """
        Compute gradients:
        gradient = (1/N)*∑(wx - y)*x

        Define wx as guest_forward or host_forward
        Define (wx-y) as fore_gradient
        Parameters
        ----------
        data_instances: DTable of Instance, input data

        model_weights: LinearRegressionWeights
            Stores coef_ and intercept_ of model

        encrypted_calculator: Use for different encrypted methods

        offset: Used in Poisson only.

        batch_index: int, use to obtain current encrypted_calculator index:

        current_suffix: tuple or string. Used in transfer_variable

        Returns
        ----------
        fore_gradient: the aggregated forwards joined with the labels as
            ``aggregated_forward - label``
        """
        wx = data_instances.mapValues(
            lambda v: vec_dot(v.features, model_weights.coef_) + model_weights.intercept_)
        self.forwards = wx
        LOGGER.debug("Guest computed forwards: u^G={}, type={}".format(self.forwards, type(self.forwards)))
        self.aggregated_forwards = encrypted_calculator[batch_index].encrypt(wx)
        LOGGER.debug("Guest encrypted forwards: [[u^G]]={}, type={}".format(self.aggregated_forwards, type(self.aggregated_forwards)))
        self.host_forwards = self.get_host_forward(suffix=current_suffix)

        for host_forward in self.host_forwards:
            self.aggregated_forwards = self.aggregated_forwards.join(host_forward, lambda g, h: g + h)
            # Log after the join so the message shows the aggregated value.
            LOGGER.debug("Guest computed: [[u^G+u^H]]={}, type={}".format(self.aggregated_forwards, type(self.aggregated_forwards)))
        fore_gradient = self.aggregated_forwards.join(data_instances, lambda fwd, d: fwd - d.label)
        # Must come after the assignment: logging fore_gradient first raised
        # UnboundLocalError on every call.
        LOGGER.debug("Guest computed: [[d]]=[[u^G+u^H-y]]={}, type={}".format(fore_gradient, type(fore_gradient)))
        return fore_gradient

    def compute_loss(self, data_instances, n_iter_, batch_index, loss_norm=None):
        """
        Compute hetero linr loss:
            loss = (1/2N)*∑(wx-y)^2 where y is label, w is model weight and x is features

        With wx = wx_h + wx_g the square expands as
            (wx - y)^2 = (wx_h)^2 + (wx_g - y)^2 + 2*wx_h*(wx_g - y)

        The loss is only computed when exactly one host forward is held;
        with more than one host an info message is logged and an empty
        loss list is synchronised.

        Parameters
        ----------
        data_instances: DTable of Instance, input data

        n_iter_: int, current iter nums

        batch_index: int, index of the current batch

        loss_norm: regular term of the guest, or None; when given, the host
            regular term is fetched and both are added to the loss
        """
        current_suffix = (n_iter_, batch_index)
        n = data_instances.count()
        loss_list = []
        host_wx_squares = self.get_host_loss_intermediate(current_suffix)

        if loss_norm is not None:
            host_loss_regular = self.get_host_loss_regular(suffix=current_suffix)
        else:
            host_loss_regular = []
        if len(self.host_forwards) > 1:
            LOGGER.info("More than one host exist, loss is not available")
        else:
            host_forward = self.host_forwards[0]
            host_wx_square = host_wx_squares[0]
            wxy = self.forwards.join(data_instances, lambda wx, d: wx - d.label)
            LOGGER.debug("Guest computed: (u^G-y)={}, type={}".format(wxy, type(wxy)))
            wxy_square = wxy.mapValues(lambda x: np.square(x)).reduce(reduce_add)
            LOGGER.debug("Guest computed: (u^G-y)^2={}, type={}".format(wxy_square, type(wxy_square)))
            loss_gh = wxy.join(host_forward, lambda g, h: g * h).reduce(reduce_add)

            LOGGER.debug("Guest computed: [[u^H(u^G-y)]]={}, type={}".format(loss_gh, type(loss_gh)))
            loss = (wxy_square + host_wx_square + 2 * loss_gh) / (2 * n)
            LOGGER.debug("Guest computed: [[(u^H+u^G-y)^2]]={}, type={}".format(loss, type(loss)))
            if loss_norm is not None:
                loss = loss + loss_norm + host_loss_regular[0]
                LOGGER.debug("Guest computed: [[L]]={}, type={}".format(loss, type(loss)))
            loss_list.append(loss)
        LOGGER.debug("In compute_loss, loss list are: {}".format(loss_list))
        self.sync_loss_info(loss_list, suffix=current_suffix)

    def compute_forward_hess(self, data_instances, delta_s, host_forwards):
        """
        To compute Hessian matrix, y, s are needed.
        g = (1/N)*∑(wx - y) * x
        y = ∇2^F(w_t)s_t = g' * s = (1/N)*∑(x * s) * x
        define forward_hess = (1/N)*∑(x * s)

        Returns
        ----------
        tuple of (forwards with every host forward added, hess vector as ndarray)
        """
        forwards = data_instances.mapValues(
            lambda v: (np.dot(v.features, delta_s.coef_) + delta_s.intercept_))
        for host_forward in host_forwards:
            forwards = forwards.join(host_forward, lambda g, h: g + h)
        hess_vector = hetero_linear_model_gradient.compute_gradient(data_instances,
                                                                    forwards,
                                                                    delta_s.fit_intercept)
        return forwards, np.array(hess_vector)


class Host(hetero_linear_model_gradient.Host, loss_sync.Host):
    """
    Host-role gradient and loss operator for hetero linear regression.
    """

    def register_gradient_procedure(self, transfer_variables):
        """Register the transfer variables used for gradient and loss synchronisation."""
        self._register_gradient_sync(transfer_variables.host_forward,
                                     transfer_variables.fore_gradient,
                                     transfer_variables.host_gradient,
                                     transfer_variables.host_optim_gradient)
        self._register_loss_sync(transfer_variables.host_loss_regular,
                                 transfer_variables.loss,
                                 transfer_variables.loss_intermediate)

    def compute_forwards(self, data_instances, model_weights):
        """
        Map every instance to ``vec_dot(features, coef_) + intercept_``.

        Parameters
        ----------
        data_instances: DTable of Instance, input data

        model_weights: object exposing ``coef_`` and ``intercept_``
        """
        wx = data_instances.mapValues(lambda v: vec_dot(v.features, model_weights.coef_) + model_weights.intercept_)
        return wx

    def compute_loss(self, model_weights, optimizer, n_iter_, batch_index, cipher_operator):
        """
        Compute the host part of the hetero linr loss:
            loss = (1/2N)*∑(wx-y)^2 where y is label, w is model weight and x is features

            Note: with wx = wx_h + wx_g,
            (wx - y)^2 = (wx_h)^2 + (wx_g - y)^2 + 2*wx_h*(wx_g - y)

        This method sends the encrypted ∑(wx_h)^2 and, when the optimizer
        yields a regular term, the encrypted regular term as well.

        Parameters
        ----------
        model_weights: weights handed to ``optimizer.loss_norm``

        optimizer: provides ``loss_norm``

        n_iter_: int, current iter nums

        batch_index: int, index of the current batch

        cipher_operator: provides ``encrypt`` for the values that are sent
        """

        current_suffix = (n_iter_, batch_index)
        self_wx_square = self.forwards.mapValues(lambda x: np.square(x)).reduce(reduce_add)
        LOGGER.debug("Host computed: (u^H)^2={}, type={}".format(self_wx_square, type(self_wx_square)))
        en_wx_square = cipher_operator.encrypt(self_wx_square)
        LOGGER.debug("Host encrypted: [[(u^H)^2]]={}, type={}".format(en_wx_square, type(en_wx_square)))
        self.remote_loss_intermediate(en_wx_square, suffix=current_suffix)
        loss_regular = optimizer.loss_norm(model_weights)

        # Nothing is sent for the regular term when the optimizer returns None.
        if loss_regular is not None:
            LOGGER.debug("Host computed: (λ/2)(Θ^H)^2={}, type={}".format(loss_regular, type(loss_regular)))
            en_loss_regular = cipher_operator.encrypt(loss_regular)
            LOGGER.debug("Host encrypted: [[(λ/2)(Θ^H)^2]]={}, type={}".format(en_loss_regular, type(en_loss_regular)))
            self.remote_loss_regular(en_loss_regular, suffix=current_suffix)


class Arbiter(hetero_linear_model_gradient.Arbiter, loss_sync.Arbiter):
    """
    Arbiter-role gradient and loss operator for hetero linear regression.
    """

    def register_gradient_procedure(self, transfer_variables):
        """Register the gradient channels first, then the loss channel."""
        gradient_channels = (
            transfer_variables.guest_gradient,
            transfer_variables.host_gradient,
            transfer_variables.guest_optim_gradient,
            transfer_variables.host_optim_gradient,
        )
        self._register_gradient_sync(*gradient_channels)
        self._register_loss_sync(transfer_variables.loss)

    def compute_loss(self, cipher, n_iter_, batch_index):
        """
        Fetch the loss list synchronised for ``(n_iter_, batch_index)`` and
        return it after ``cipher.decrypt_list``.
        """
        encrypted_losses = self.sync_loss_info(suffix=(n_iter_, batch_index))
        return cipher.decrypt_list(encrypted_losses)

 

3. 数据预测：使用FATE纵向线性回归算法的模型训练结果进行预测比较简单，中间参数也未采用加密手段进行保护，其流程及交互消息如下图所示。

 

源码:

在源码中添加日志:

./federatedml/linear_model/linear_regression/hetero_linear_regression/hetero_linr_guest.py
HeteroLinRGuest->fit()
 L67:
  LOGGER.debug("Guest received from Arbiter: PK={}".format(self.cipher_operator.get_public_key()))
HeteroLinRGuest->predict()
 L148:
  LOGGER.debug("Guest computed: u^G={}, type={}".format(pred, type(pred)))
 L150:
  LOGGER.debug("Guest received from Host: u^H={}, type={}".format(host_preds, type(host_preds)))
 L157:
  LOGGER.debug("Guest computed: y={}, type={}".format(predict_result, type(predict_result)))

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

from federatedml.framework.hetero.procedure import convergence
from federatedml.framework.hetero.procedure import paillier_cipher, batch_generator
from federatedml.linear_model.linear_model_weight import LinearModelWeights
from federatedml.linear_model.linear_regression.hetero_linear_regression.hetero_linr_base import HeteroLinRBase
from federatedml.optim.gradient import hetero_linr_gradient_and_loss
from federatedml.secureprotol import EncryptModeCalculator
from federatedml.util import LOGGER
from federatedml.util import consts
from federatedml.util.io_check import assert_io_num_rows_equal


class HeteroLinRGuest(HeteroLinRBase):
    """
    Guest-role driver for hetero linear regression.

    Holds the guest-side cipher, batch generator, gradient/loss operator and
    convergence helper, and exposes ``fit`` and ``predict``.
    """

    def __init__(self):
        super().__init__()
        self.data_batch_count = []
        # self.guest_forward = None
        self.role = consts.GUEST
        # Guest-side halves of the procedures used in fit().
        self.cipher = paillier_cipher.Guest()
        self.batch_generator = batch_generator.Guest()
        self.gradient_loss_operator = hetero_linr_gradient_and_loss.Guest()
        self.converge_procedure = convergence.Guest()
        # Filled in fit() with one EncryptModeCalculator per batch.
        self.encrypted_calculator = None

    @staticmethod
    def load_data(data_instance):
        """
        return data_instance as original
        Parameters
        ----------
        data_instance: DTable of Instance, input data
        """
        return data_instance

    def fit(self, data_instances, validate_data=None):
        """
        Train linR model of role guest
        Parameters
        ----------
        data_instances: DTable of Instance, input data

        validate_data: optional data handed to ``init_validation_strategy``
        """

        LOGGER.info("Enter hetero_linR_guest fit")
        self._abnormal_detection(data_instances)
        self.header = self.get_header(data_instances)

        self.validation_strategy = self.init_validation_strategy(data_instances, validate_data)

        # NOTE(review): the debug message below describes the public key as
        # received from the Arbiter - confirm against paillier_cipher.Guest.
        self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

        LOGGER.info("Generate mini-batch from input data")
        LOGGER.debug("Guest received from Arbiter: PK={}".format(self.cipher_operator.get_public_key()))
        self.batch_generator.initialize_batch_generator(data_instances, self.batch_size)
        self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_nums)

        # One calculator per batch; the gradient operator picks one by batch_index.
        self.encrypted_calculator = [EncryptModeCalculator(self.cipher_operator,
                                                           self.encrypted_mode_calculator_param.mode,
                                                           self.encrypted_mode_calculator_param.re_encrypted_rate) for _
                                     in range(self.batch_generator.batch_nums)]

        LOGGER.info("Start initialize model.")
        LOGGER.info("fit_intercept:{}".format(self.init_param_obj.fit_intercept))
        model_shape = self.get_features_shape(data_instances)
        w = self.initializer.init_model(model_shape, init_params=self.init_param_obj)
        self.model_weights = LinearModelWeights(w, fit_intercept=self.fit_intercept)

        # Outer loop: one pass over all batches per iteration, until max_iter,
        # convergence, or an early stop requested by the validation strategy.
        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:{}".format(self.n_iter_))
            # each iter will get the same batch_data_generator
            batch_data_generator = self.batch_generator.generate_batch_data()
            self.optimizer.set_iters(self.n_iter_)
            batch_index = 0
            for batch_data in batch_data_generator:
                # transforms features of raw input 'batch_data_inst' into more representative features 'batch_feat_inst'
                batch_feat_inst = self.transform(batch_data)

                # Start gradient procedure
                # Only the first element of the returned triple is used here.
                optim_guest_gradient, _, _ = self.gradient_loss_operator.compute_gradient_procedure(
                    batch_feat_inst,
                    self.encrypted_calculator,
                    self.model_weights,
                    self.optimizer,
                    self.n_iter_,
                    batch_index
                )

                # The loss is computed with the weights as they are before this
                # batch's update is applied.
                # NOTE(review): compute_loss is given the full data_instances,
                # not the current batch - confirm this is intended.
                loss_norm = self.optimizer.loss_norm(self.model_weights)
                self.gradient_loss_operator.compute_loss(data_instances, self.n_iter_, batch_index, loss_norm)

                self.model_weights = self.optimizer.update_model(self.model_weights, optim_guest_gradient)
                batch_index += 1
                # LOGGER.debug(
                #     "model_weights, iters: {}, update_model: {}".format(self.n_iter_, self.model_weights.unboxed))

            # Convergence flag for this iteration, obtained through the convergence helper.
            self.is_converged = self.converge_procedure.sync_converge_info(suffix=(self.n_iter_,))
            LOGGER.info("iter: {},  is_converged: {}".format(self.n_iter_, self.is_converged))

            # LOGGER.debug("model weights is {}".format(self.model_weights.coef_))

            if self.validation_strategy:
                LOGGER.debug('LinR guest running validation')
                self.validation_strategy.validate(self, self.n_iter_)
                # An early stop leaves the loop before n_iter_ is incremented.
                if self.validation_strategy.need_stop():
                    LOGGER.debug('early stopping triggered')
                    break

            self.n_iter_ += 1
            if self.is_converged:
                break
        # Restore the best model kept by the validation strategy, if any.
        if self.validation_strategy and self.validation_strategy.has_saved_best_model():
            self.load_model(self.validation_strategy.cur_best_model)
        self.set_summary(self.get_model_summary())

    @assert_io_num_rows_equal
    def predict(self, data_instances):
        """
        Prediction of linR
        Parameters
        ----------
        data_instances: DTable of Instance, input data

        Returns
        ----------
        DTable
            the value produced by ``predict_score_to_output`` for the summed
            guest and host partial predictions
        """
        LOGGER.info("Start predict ...")
        self._abnormal_detection(data_instances)
        data_instances = self.align_data_header(data_instances, self.header)
        data_features = self.transform(data_instances)
        # Guest partial prediction from the local weights and intercept.
        pred = self.compute_wx(data_features, self.model_weights.coef_, self.model_weights.intercept_)
        LOGGER.debug("Guest computed: u^G={}, type={}".format(pred, type(pred)))
        # idx=-1: the result is iterated below, one entry per host.
        host_preds = self.transfer_variable.host_partial_prediction.get(idx=-1)
        LOGGER.debug("Guest received from Host: u^H={}, type={}".format(host_preds, type(host_preds)))
        LOGGER.info("Get prediction from Host")

        # Add every host partial prediction to the guest one, key by key.
        for host_pred in host_preds:
            pred = pred.join(host_pred, lambda g, h: g + h)
        # predict_result = data_instances.join(pred, lambda d, pred: [d.label, pred, pred, {"label": pred}])
        predict_result = self.predict_score_to_output(data_instances=data_instances, predict_score=pred,
                                                      classes=None)
        LOGGER.debug("Guest computed: y={}, type={}".format(predict_result, type(predict_result)))
        return predict_result

./federatedml/linear_model/linear_regression/hetero_linear_regression/hetero_linr_host.py
HeteroLinRHost->fit()
 L58:
  LOGGER.debug("Host receives from Arbiter: PK={}".format(self.cipher_operator.get_public_key()))
HeteroLinRHost->predict()
 L131:
  LOGGER.debug("Guest computed: u^H={}, type={}".format(pred_host, type(pred_host)))
 L133:
  LOGGER.debug("Host sent to Guest: u^H={}, type={}".format(pred_host, type(pred_host)))

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

from federatedml.framework.hetero.procedure import convergence
from federatedml.framework.hetero.procedure import paillier_cipher, batch_generator
from federatedml.linear_model.linear_model_weight import LinearModelWeights
from federatedml.linear_model.linear_regression.hetero_linear_regression.hetero_linr_base import HeteroLinRBase
from federatedml.optim.gradient import hetero_linr_gradient_and_loss
from federatedml.secureprotol import EncryptModeCalculator
from federatedml.util import LOGGER
from federatedml.util import consts


class HeteroLinRHost(HeteroLinRBase):
    """Host-role participant of hetero linear regression.

    The host owns a share of the features, trains its own slice of the model
    weights in ``fit`` and, in ``predict``, sends its partial prediction to
    the guest.
    """

    def __init__(self):
        super(HeteroLinRHost, self).__init__()
        self.batch_num = None
        self.batch_index_list = []
        self.role = consts.HOST

        # Host-side counterparts of the federated training procedures.
        self.cipher = paillier_cipher.Host()
        self.batch_generator = batch_generator.Host()
        self.gradient_loss_operator = hetero_linr_gradient_and_loss.Host()
        self.converge_procedure = convergence.Host()
        # Filled in by fit(): one EncryptModeCalculator per mini-batch.
        self.encrypted_calculator = None

    def fit(self, data_instances, validate_data=None):
        """
        Train linear regression model of role host
        Parameters
        ----------
        data_instances: DTable of Instance, input data
        validate_data: optional data handed to ``init_validation_strategy``
            (presumably a DTable of Instance used for validation -- confirm
            against the base class)
        """

        LOGGER.info("Enter hetero_linR host")
        self._abnormal_detection(data_instances)

        self.validation_strategy = self.init_validation_strategy(data_instances, validate_data)

        self.header = self.get_header(data_instances)
        self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

        self.batch_generator.initialize_batch_generator(data_instances)
        self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_nums)
        # NOTE(review): per the log text the public key originates from the
        # arbiter; the exchange itself happens inside gen_paillier_cipher_operator.
        LOGGER.debug("Host receives from Arbiter: PK={}".format(self.cipher_operator.get_public_key()))
        # One calculator per batch, indexed by batch_index in the training loop.
        self.encrypted_calculator = [
            EncryptModeCalculator(self.cipher_operator,
                                  self.encrypted_mode_calculator_param.mode,
                                  self.encrypted_mode_calculator_param.re_encrypted_rate)
            for _ in range(self.batch_generator.batch_nums)
        ]

        LOGGER.info("Start initialize model.")
        model_shape = self.get_features_shape(data_instances)
        # The host never fits an intercept: force the init parameter off
        # before the weights are initialized.
        if self.init_param_obj.fit_intercept:
            self.init_param_obj.fit_intercept = False

        w = self.initializer.init_model(model_shape, init_params=self.init_param_obj)
        self.model_weights = LinearModelWeights(w, fit_intercept=self.fit_intercept)

        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:" + str(self.n_iter_))
            self.optimizer.set_iters(self.n_iter_)
            batch_data_generator = self.batch_generator.generate_batch_data()
            batch_index = 0
            for batch_data in batch_data_generator:
                batch_feat_inst = self.transform(batch_data)
                # Only the gradient is used here; the second return value is ignored.
                optim_host_gradient, _ = self.gradient_loss_operator.compute_gradient_procedure(
                    batch_feat_inst,
                    self.encrypted_calculator,
                    self.model_weights,
                    self.optimizer,
                    self.n_iter_,
                    batch_index)

                # Loss is computed with the weights from BEFORE this batch's update.
                self.gradient_loss_operator.compute_loss(self.model_weights, self.optimizer, self.n_iter_, batch_index,
                                                         self.cipher_operator)

                self.model_weights = self.optimizer.update_model(self.model_weights, optim_host_gradient)
                batch_index += 1

            self.is_converged = self.converge_procedure.sync_converge_info(suffix=(self.n_iter_,))

            LOGGER.info("Get is_converged flag from arbiter:{}".format(self.is_converged))

            if self.validation_strategy:
                LOGGER.debug('LinR host running validation')
                self.validation_strategy.validate(self, self.n_iter_)
                # Early stop leaves the loop before n_iter_ is incremented.
                if self.validation_strategy.need_stop():
                    LOGGER.debug('early stopping triggered')
                    break

            self.n_iter_ += 1
            LOGGER.info("iter: {}, is_converged: {}".format(self.n_iter_, self.is_converged))
            if self.is_converged:
                break
        if not self.is_converged:
            LOGGER.info("Reach max iter {}, train model finish!".format(self.max_iter))
        # Restore the best model recorded during validation, if there is one.
        if self.validation_strategy and self.validation_strategy.has_saved_best_model():
            self.load_model(self.validation_strategy.cur_best_model)
        self.set_summary(self.get_model_summary())
        # LOGGER.debug(f"summary content is: {self.summary()}")

    def predict(self, data_instances):
        """
        Prediction of linR
        Computes the host's partial prediction and sends it to the guest;
        nothing is returned to the caller.
        Parameters
        ----------
        data_instances:DTable of Instance, input data
        """
        self.transfer_variable.host_partial_prediction.disable_auto_clean()

        LOGGER.info("Start predict ...")

        self._abnormal_detection(data_instances)
        data_instances = self.align_data_header(data_instances, self.header)
        data_features = self.transform(data_instances)

        pred_host = self.compute_wx(data_features, self.model_weights.coef_, self.model_weights.intercept_)
        # This value is computed locally by the host, so the log must say "Host".
        LOGGER.debug("Host computed: u^H={}, type={}".format(pred_host, type(pred_host)))
        self.transfer_variable.host_partial_prediction.remote(pred_host, role=consts.GUEST, idx=0)
        LOGGER.debug("Host sent to Guest: u^H={}, type={}".format(pred_host, type(pred_host)))
        LOGGER.info("Remote partial prediction to Guest")

 

 

文章来自个人专栏
多方隐私计算
1 文章 | 1 订阅
0条评论
0 / 1000
请输入你的评论
0
0