Piano: Extremely Simple, Single-server PIR with Sublinear Server Computation #196

Open
wants to merge 13 commits into
base: main

Conversation

cxiao129

Piano: Extremely Simple, Single-server PIR with Sublinear Server Computation

github-actions bot commented Oct 29, 2024

All contributors have signed the CLA ✍️ ✅
Posted by the CLA Assistant Lite bot.

@cxiao129
Author

I have read the CLA Document and I hereby sign the CLA

Contributor

qxzhou1010 left a comment

Please put the entire Piano implementation under the psi/experimental/ directory.

psi/piano/util.h Outdated
std::string Uint128ToBytes(uint128_t value);

// Generates a random 128-bit key using the provided RNG
PrfKey128 RandKey128(std::mt19937_64& rng);
Contributor

Use yacl::crypto::Prg here, and replace every other place that involves an rng with the same Prg.

psi/piano/util.h Outdated

namespace psi::piano {

constexpr size_t DBEntrySize = 8; // has to be a multiple of 8
Contributor

Does this refer to the length of one row in the database? Is the unit bits or bytes? And where does the value 8 come from?

Author

The unit is bytes. The data was originally represented as vector<uint64_t>, hence the 8. It now uses vector<uint8_t>, and the number of bytes per database row is determined from the parameters.

psi/piano/util.h Outdated
constexpr size_t DBEntrySize = 8; // has to be a multiple of 8
constexpr size_t DBEntryLength = DBEntrySize / 8;

using PrfKey128 = uint128_t;
Contributor

Include the header #include "yacl/base/int128.h" and explicitly use the uint128_t type that yacl provides.

psi/piano/util.h Outdated
PrfKey RandKey(std::mt19937_64& rng);

// Evaluates PRF using 128-bit key and returns a 64-bit result
uint64_t PRFEval128(const PrfKey128* key, uint64_t x);
Contributor

A uint128_t can be passed by value directly; there is no need for a pointer.

psi/piano/util.h Outdated
constexpr size_t DBEntryLength = DBEntrySize / 8;

using PrfKey128 = uint128_t;
using DBEntry = std::array<uint64_t, DBEntryLength>;
Contributor

There are a lot of DBEntry-related operations further down. Could DBEntry be abstracted further into a class or struct?

Author

Wrapped it in a class.

psi/piano/util.h Outdated
using DBEntry = std::array<uint64_t, DBEntryLength>;
using PrfKey = PrfKey128;

uint128_t BytesToUint128(const std::string& bytes);
Contributor

Use yacl's ByteContainerView or std::vector<uint8_t> to represent bytes, rather than std::string.
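
As an illustration of the two suggestions on this header (bytes as std::vector<uint8_t>, and 128-bit values passed by value), here is a minimal sketch. It assumes a little-endian byte layout, which the thread does not specify, and uses the hypothetical alias `U128` over the GCC/Clang `unsigned __int128` extension as a stand-in for yacl's `uint128_t`.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for yacl's uint128_t (GCC/Clang extension type).
using U128 = unsigned __int128;

// Serializes a 128-bit value into 16 bytes, least significant byte first.
// The byte order is an assumption made for this sketch.
inline std::vector<uint8_t> Uint128ToBytes(U128 value) {
  std::vector<uint8_t> bytes(16);
  for (size_t i = 0; i < 16; ++i) {
    bytes[i] = static_cast<uint8_t>(value >> (8 * i));
  }
  return bytes;
}

// Parses exactly 16 little-endian bytes back into a 128-bit value.
// Rejects any other length instead of reading out of bounds.
inline U128 BytesToUint128(const std::vector<uint8_t>& bytes) {
  if (bytes.size() != 16) {
    throw std::invalid_argument("expected exactly 16 bytes");
  }
  U128 value = 0;
  for (size_t i = 0; i < 16; ++i) {
    value |= static_cast<U128>(bytes[i]) << (8 * i);
  }
  return value;
}
```

With a byte vector the length is known at the call site, so a malformed input can be rejected up front, which a raw std::string of arbitrary text does not make obvious.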


class QueryServiceClient {
public:
static constexpr uint64_t FailureProbLog2 = 40;
Contributor

Rename this to kFailureProbLog2.

class QueryServiceClient {
public:
static constexpr uint64_t FailureProbLog2 = 40;
uint64_t totalQueryNum{};
Contributor

Please put this under private.


std::vector<LocalSet> primary_sets_;
std::vector<LocalBackupSet> local_backup_sets_;
std::map<uint64_t, DBEntry> local_cache_;
Contributor

If ordering is not required, std::unordered_map is enough.

psi/piano/util.h Outdated

class PRSetWithShortTag {
public:
uint32_t Tag;
Contributor

Put it under private, or simply use a struct.

return entry;
}

uint64_t DefaultHash(uint64_t key) {
Contributor

Where does this method come from?

Author

This is only used to generate test data, so I used the simple and efficient FNV hash. It has been moved into the test code. The logic and the choice of parameters follow https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
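
For reference, the FNV hash mentioned in this reply can be sketched as below. The reply does not say which variant is used; this sketch assumes 64-bit FNV-1a with the standard offset basis and prime. `Fnv1a64` and `MakeTestEntry` are hypothetical names, and the way the entry bytes are derived from the hash is an illustration only, not the PR's code.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Standard 64-bit FNV parameters, named with the k prefix and kept at
// namespace scope rather than inside a function.
constexpr uint64_t kFnvOffsetBasis = 14695981039346656037ULL;
constexpr uint64_t kFnvPrime = 1099511628211ULL;

// FNV-1a over a byte buffer: XOR each byte in, then multiply by the prime.
inline uint64_t Fnv1a64(const uint8_t* data, size_t len) {
  uint64_t hash = kFnvOffsetBasis;
  for (size_t i = 0; i < len; ++i) {
    hash ^= data[i];
    hash *= kFnvPrime;
  }
  return hash;
}

// Hypothetical test-data helper: fills an entry of entry_size bytes for
// index `key` by hashing (key, counter) and emitting 8 bytes per hash.
inline std::vector<uint8_t> MakeTestEntry(uint64_t key, size_t entry_size) {
  std::vector<uint8_t> entry;
  entry.reserve(entry_size);
  for (uint64_t counter = 0; entry.size() < entry_size; ++counter) {
    uint8_t input[16];
    for (int i = 0; i < 8; ++i) {
      input[i] = static_cast<uint8_t>(key >> (8 * i));
      input[8 + i] = static_cast<uint8_t>(counter >> (8 * i));
    }
    const uint64_t h = Fnv1a64(input, sizeof(input));
    for (int i = 0; i < 8 && entry.size() < entry_size; ++i) {
      entry.push_back(static_cast<uint8_t>(h >> (8 * i)));
    }
  }
  return entry;
}
```

FNV is not a cryptographic hash, which is fine for deterministic test fixtures but is a reason to keep it out of the library code, as the reply says was done.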

qxzhou1010 self-requested a review December 10, 2024 07:11
SPDLOG_INFO("DB N: %lu, Entry Size %lu Bytes, DB Size %lu MB\n",
params.db_size, DBEntrySize,
params.db_size * DBEntrySize / 1024 / 1024);
SPDLOG_INFO("DB N: %lu, Entry Size %lu Bytes, DB Size %lu MB\n", entry_num,
Contributor

This SPDLOG_INFO call is wrong: use {} directly as the placeholder. The test output looks like this:

[2024-12-10 07:22:51.291] [info] [piano_test.cc:91] DB N: %lu, Entry Size %lu Bytes, DB Size %lu MB

[2024-12-10 07:22:51.291] [info] [piano_test.cc:95] Chunk Size: %lu, Set Size: %lu

[2024-12-10 07:22:51.307] [info] [piano_test.cc:99] DB Real N: %lu

Come on, once you have written the code, please at least run the tests yourself and look at the output.

Author

In my earlier testing I only ran bazel test, so the output was incomplete. I sincerely apologize for that. I have now tested with both bazel run and the binary under bazel-bin to make sure the messages are displayed correctly.


private:
uint64_t db_size_;
static constexpr uint64_t kFailureProbLog2 = 40;
Contributor

This constant does not need to be defined here, does it? Constants usually go at the top of the file.

@@ -78,24 +69,28 @@ class QueryServiceClient {
DBEntry OnlineSingleQuery(uint64_t x);
std::vector<DBEntry> OnlineMultipleQueries(
const std::vector<uint64_t>& queries);
uint64_t getTotalQueryNumber() const { return total_query_num_; };
Contributor

Method names should consistently start with an uppercase letter. Please check for similar issues elsewhere and fix them.

Author

Checked and fixed.

uint64_t db_size_;
static constexpr uint64_t kFailureProbLog2 = 40;
uint64_t total_query_num_{};
uint64_t entry_num_;
Contributor

Initialize this member.

@@ -67,8 +68,17 @@ void QueryServiceClient::InitializeLocalSets() {
local_miss_elements_.clear();
uint32_t tagCounter = 0;
Contributor

Local variable names should consistently use lowercase words joined by underscores. There are many other places in the code with inconsistent naming; please check and fix them yourself.

Author

Checked and fixed.

@@ -45,13 +46,34 @@ std::vector<DBEntry> RunClient(QueryServiceClient& client,
return client.OnlineMultipleQueries(queries);
}

std::vector<uint8_t> FNVHash(uint64_t key) {
constexpr uint64_t FNV_offset_basis = 14695981039346656037ULL;
Contributor

Constant names should start with k followed by camel case, and they do not need to be defined inside the function.

const auto actual_query_num = params.is_total_query_num
? client.getTotalQueryNumber()
: params.query_num;
const auto queries = GenerateQueries(actual_query_num, entry_num);
Contributor

Shouldn't generating queries be one of the methods the Client provides? Why is it a free function here?

Author

This generates random indices as test data. In a two-party setting they can be generated locally on the client and then passed as the argument to OnlineMultipleQueries. In addition, the test needs to simulate querying the database directly by plaintext index in order to verify the correctness of the PIR scheme.

std::pair<uint64_t, uint64_t> GenParams(const uint64_t db_size) {
const double targetChunkSize = 2 * std::sqrt(static_cast<double>(db_size));
std::pair<uint64_t, uint64_t> GenParams(const uint64_t entry_num) {
const double targetChunkSize = 2 * std::sqrt(static_cast<double>(entry_num));
uint64_t ChunkSize = 1;
Contributor

Variable naming style again.


namespace pir::piano {

uint64_t primaryNumParam(const double q, const double chunk_size,
Contributor

Method name style issue. It would also be best to note where this calculation comes from.

Author

Added a comment explaining the formula.


namespace pir::piano {

struct LocalSet {
Contributor

LocalSet, LocalBackupSet, LocalBackupSetGroup, LocalReplacementGroup

What are these types mainly used for? Please add meaningful comments explaining them.

Author

Added explanatory comments.

Contributor

OK. Once the changes are done, please push the latest commit.

Author

Commit pushed. Should I move the code from psi/experimental/ to psi/experiment/? I see that the existing experimental code all lives in that directory.

Contributor

Yes, please move it to the psi/experiment/pir/ directory. My apologies: in my earlier comments I wrote experimental by mistake.

const uint64_t log2_k_ = 40;
// Converts log2_k_ from base-2 to natural logarithm using the change of base
// formula
const double natural_log_k_ = std::log(2) * static_cast<double>(log2_k_);
Contributor

I suggest making these two variables global const variables, placed under the namespace in some .h file.

Author

Changed them to constexpr and placed them in the namespace in client.h.

const uint64_t primary_set_per_thread =
((primary_set_num_ + thread_num_ - 1) / thread_num_) + 1;
const uint64_t backup_set_per_thread =
((total_backup_set_num_ + thread_num_ - 1) / thread_num_) + 1;
Contributor

The preceding computation already rounds up, so why is another +1 needed on the outside?

Author

cxiao129 Dec 20, 2024

Yes, the +1 here is redundant. I have removed it.
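
To make the point about the redundant +1 concrete, here is a small sketch of ceiling division on unsigned integers. `CeilDiv` is a hypothetical helper name, not something from the PR.

```cpp
#include <cstdint>

// Ceiling division: the smallest q such that q * divisor >= value.
// Preconditions (not checked here): divisor != 0 and
// value + divisor - 1 does not overflow uint64_t.
constexpr uint64_t CeilDiv(uint64_t value, uint64_t divisor) {
  return (value + divisor - 1) / divisor;
}

// 10 sets over 4 threads: 3 per thread already covers all 10.
static_assert(CeilDiv(10, 4) == 3, "rounds up");
// An exact multiple is left unchanged.
static_assert(CeilDiv(8, 4) == 2, "no rounding needed");
// Adding 1 on top would hand every thread one set more than required.
static_assert(CeilDiv(10, 4) + 1 == 4, "over-allocation with the extra +1");
```

With 10 sets and 4 threads, 3 sets per thread gives capacity for 12, so the additional +1 only inflates the per-thread count.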


// Store the replacement
yacl::crypto::Prg<uint64_t> prg(yacl::crypto::SecureRandU64());
for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) {
Contributor

I don't quite follow this part. It should correspond to the third step of the offline preprocessing phase in Figure 1 of the paper, "Update backup table".

In the paper's description, for i in {0, 1, 2, ...} \ {j}, the index i should iterate over every chunk, right? The j-th chunk currently being processed would of course be skipped, but I don't see that logic in the code. Instead, every i seems to be used for the processing here without distinction?

Author

Previously, all chunks were processed inside the worker threads, and the i-th chunk was XORed out afterwards to implement the skip. That incurred some unnecessary computation, so it now skips inside the worker threads, based on the correspondence between backup hints and chunks.
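
The two approaches described in this exchange can be compared on a toy model. This is a sketch under simplifying assumptions, not the PR's code: entries are single 64-bit words, the per-chunk offsets are passed in directly instead of being derived from a PRF, and both function names are hypothetical.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy database: chunks of chunk_size entries stored back to back in db.
// offsets[i] is the in-chunk position a backup hint selects in chunk i.

// Direct form: XOR one entry from every chunk except skip_chunk.
inline uint64_t BackupParitySkipping(const std::vector<uint64_t>& db,
                                     size_t chunk_size,
                                     const std::vector<size_t>& offsets,
                                     size_t skip_chunk) {
  uint64_t parity = 0;
  for (size_t i = 0; i < offsets.size(); ++i) {
    if (i == skip_chunk) continue;  // the chunk this hint must not cover
    parity ^= db[i * chunk_size + offsets[i]];
  }
  return parity;
}

// Earlier form from the reply: XOR every chunk in, then XOR the skipped
// chunk's entry out again. Same result, one wasted pair of XORs per hint.
inline uint64_t BackupParityXorOut(const std::vector<uint64_t>& db,
                                   size_t chunk_size,
                                   const std::vector<size_t>& offsets,
                                   size_t skip_chunk) {
  uint64_t parity = 0;
  for (size_t i = 0; i < offsets.size(); ++i) {
    parity ^= db[i * chunk_size + offsets[i]];
  }
  parity ^= db[skip_chunk * chunk_size + offsets[skip_chunk]];
  return parity;
}
```

Both forms agree because XORing a value twice cancels it, which is why the earlier code was correct even though the skip was not visible inside the loop.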

static_cast<double>(chunk_size_), kFailureProbLog2 + 1);
// if localSetNum is not a multiple of thread_num_ then we need to add some
// padding
primary_set_num_ =
Collaborator

Should thread_num be validated? What happens if it is 0?
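
One way to address this, sketched with a standard exception: reject a zero thread count before it is used as a divisor in the padding computation. `RoundUpToMultiple` is a hypothetical helper, and the project may well prefer its own enforcement macro over `std::invalid_argument`.

```cpp
#include <cstdint>
#include <stdexcept>

// Rounds count up to the nearest multiple of thread_num so that the sets
// split evenly across threads. Throws instead of dividing by zero.
inline uint64_t RoundUpToMultiple(uint64_t count, uint64_t thread_num) {
  if (thread_num == 0) {
    throw std::invalid_argument("thread_num must be greater than 0");
  }
  return (count + thread_num - 1) / thread_num * thread_num;
}
```

Validating once, in the constructor or in a helper like this, keeps every later division by the thread count safe.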

(primary_set_num_ + thread_num_ - 1) / thread_num_ * thread_num_;

backup_set_num_per_chunk_ =
3 * static_cast<uint64_t>(static_cast<double>(total_query_num_) /
Collaborator

The conversion here looks unnecessary.

}

void QueryServiceClient::InitializeLocalSets() {
primary_sets_.clear();
Collaborator

If this is only called during initialization, is the clear necessary?


// Initialize local_backup_set_groups_ and local_replacement_groups_
for (uint64_t i = 0; i < set_size_; i++) {
std::vector<std::reference_wrapper<LocalBackupSet>> backupSets;
Collaborator

It looks like a span of LocalBackupSet would be enough here; there is no need for vector<reference_wrapper>.

@@ -0,0 +1,334 @@
#include "experiment/pir/piano/client.h"
Collaborator

Please add the copyright file header.

}

void XorFromRaw(const uint8_t* src) {
for (size_t i = 0; i < k_length_; ++i) {
Collaborator

Use a span, which makes length validation easy.

}

private:
size_t k_length_{};
Collaborator

Is k_length_ really necessary? Would using data.size() directly work?

uint64_t chunk_size = 1;

// Ensure chunk_size is a power of 2 and not smaller than target_chunk_size
while (chunk_size < static_cast<uint64_t>(target_chunk_size)) {
Collaborator

Comparing as double here seems more robust.
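
A small sketch of why the comparison type matters. Truncating the target to uint64_t before comparing can stop the loop one step early, leaving a chunk size smaller than the target. The loop body (doubling) is an assumption, since the thread only shows the loop condition, and both function names are hypothetical.

```cpp
#include <cmath>
#include <cstdint>

// Smallest power of two that is >= 2 * sqrt(entry_num), compared as double.
inline uint64_t ChunkSizeFor(uint64_t entry_num) {
  const double target_chunk_size =
      2 * std::sqrt(static_cast<double>(entry_num));
  uint64_t chunk_size = 1;
  while (static_cast<double>(chunk_size) < target_chunk_size) {
    chunk_size *= 2;
  }
  return chunk_size;
}

// The truncating comparison from the snippet under review, for contrast.
inline uint64_t ChunkSizeTruncating(uint64_t entry_num) {
  const double target_chunk_size =
      2 * std::sqrt(static_cast<double>(entry_num));
  uint64_t chunk_size = 1;
  while (chunk_size < static_cast<uint64_t>(target_chunk_size)) {
    chunk_size *= 2;
  }
  return chunk_size;
}
```

For entry_num = 5 the target is about 4.47. The truncating version stops at 4, which is below the target, while the double comparison continues to 8.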

uint64_t entry_num, uint64_t thread_num,
uint64_t entry_size);

void Initialize();
Collaborator

If these two Initialize methods are only called from the constructor, I suggest making them private.

}

void QueryServiceClient::InitializeLocalSets() {
primary_sets_.clear();
Collaborator

If this is only called from the constructor, the clear is unnecessary. If something like clear is needed, add a separate reset method.

3 participants