Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Piano: Extremely Simple, Single-server PIR with Sublinear Server Computation #196

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions experimental/pir/piano/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ psi_cc_library(
hdrs = ["util.h"],
deps = [
"@yacl//yacl/crypto/aes:aes_intrinsics",
"@yacl//yacl/crypto/rand",
],
)

Expand Down Expand Up @@ -67,7 +68,6 @@ psi_cc_library(
":piano_cc_proto",
":serialize",
":util",
"@yacl//yacl/crypto/rand",
"@yacl//yacl/link:context",
],
)
Expand All @@ -80,7 +80,6 @@ psi_cc_test(
":client",
":server",
":util",
"@yacl//yacl/crypto/rand",
"@yacl//yacl/link:context",
"@yacl//yacl/link:test_util",
],
Expand Down
101 changes: 52 additions & 49 deletions experimental/pir/piano/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,37 @@ double FailProbBallIntoBins(const uint64_t ball_num, const uint64_t bin_num,
}

QueryServiceClient::QueryServiceClient(
const uint64_t db_size, const uint64_t thread_num,
std::shared_ptr<yacl::link::Context> context)
: db_size_(db_size), thread_num_(thread_num), context_(std::move(context)) {
const uint64_t entry_num, const uint64_t thread_num,
const uint64_t entry_size, std::shared_ptr<yacl::link::Context> context)
: entry_num_(entry_num),
thread_num_(thread_num),
context_(std::move(context)),
entry_size_(entry_size) {
Initialize();
InitializeLocalSets();
}

void QueryServiceClient::Initialize() {
std::mt19937_64 rng(yacl::crypto::FastRandU64());

master_key_ = RandKey(rng);
long_key_ = GetLongKey(&master_key_);
master_key_ = SecureRandKey();
long_key_ = GetLongKey(master_key_);

// Q = sqrt(n) * ln(n)
totalQueryNum =
static_cast<uint64_t>(std::sqrt(static_cast<double>(db_size_)) *
std::log(static_cast<double>(db_size_)));
total_query_num_ =
static_cast<uint64_t>(std::sqrt(static_cast<double>(entry_num_)) *
std::log(static_cast<double>(entry_num_)));

std::tie(chunk_size_, set_size_) = GenParams(db_size_);
std::tie(chunk_size_, set_size_) = GenParams(entry_num_);

primary_set_num_ =
primaryNumParam(static_cast<double>(totalQueryNum),
static_cast<double>(chunk_size_), FailureProbLog2 + 1);
primaryNumParam(static_cast<double>(total_query_num_),
static_cast<double>(chunk_size_), kFailureProbLog2 + 1);
// if localSetNum is not a multiple of thread_num_ then we need to add some
// padding
primary_set_num_ =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thread_num是不是要做检验?如果是0咋办

(primary_set_num_ + thread_num_ - 1) / thread_num_ * thread_num_;

backup_set_num_per_chunk_ =
3 * static_cast<uint64_t>(static_cast<double>(totalQueryNum) /
3 * static_cast<uint64_t>(static_cast<double>(total_query_num_) /
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的转换看着没必要

static_cast<double>(set_size_));
backup_set_num_per_chunk_ =
(backup_set_num_per_chunk_ + thread_num_ - 1) / thread_num_ * thread_num_;
Expand All @@ -67,8 +68,17 @@ void QueryServiceClient::InitializeLocalSets() {
local_miss_elements_.clear();
uint32_t tagCounter = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

局部变量的命名风格统一用下划线来连接小写字符。代码里还有很多类似的命名风格不统一的地方,请自己检查并修改。

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已检查并修改


// Initialize primary_sets_
for (uint64_t j = 0; j < primary_set_num_; j++) {
primary_sets_.emplace_back(tagCounter, ZeroEntry(), 0, false);
primary_sets_.emplace_back(tagCounter, DBEntry::ZeroEntry(entry_size_), 0,
false);
tagCounter += 1;
}

// Initialize local_backup_sets_
for (uint64_t i = 0; i < total_backup_set_num_; ++i) {
local_backup_sets_.emplace_back(tagCounter,
DBEntry::ZeroEntry(entry_size_));
tagCounter += 1;
}

Expand All @@ -77,6 +87,7 @@ void QueryServiceClient::InitializeLocalSets() {
local_replacement_groups_.clear();
local_replacement_groups_.reserve(set_size_);

// Initialize local_backup_set_groups_ and local_replacement_groups_
for (uint64_t i = 0; i < set_size_; i++) {
std::vector<std::reference_wrapper<LocalBackupSet>> backupSets;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

看着感觉使用LocalBackupSet的span就可以,没必要用vector<reference_wrapper>

for (uint64_t j = 0; j < backup_set_num_per_chunk_; j++) {
Expand All @@ -91,14 +102,6 @@ void QueryServiceClient::InitializeLocalSets() {
LocalReplacementGroup replacementGroup(0, indices, values);
local_replacement_groups_.emplace_back(std::move(replacementGroup));
}

for (uint64_t j = 0; j < set_size_; j++) {
for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) {
local_backup_set_groups_[j].sets[k].get() =
LocalBackupSet{tagCounter, ZeroEntry()};
tagCounter += 1;
}
}
}

void QueryServiceClient::FetchFullDB() {
Expand Down Expand Up @@ -145,17 +148,16 @@ void QueryServiceClient::FetchFullDB() {
std::lock_guard<std::mutex> lock(hitMapMutex);
hitMap[offset] = true;
}
DBEntryXorFromRaw(&primary_sets_[j].parity,
&chunk[offset * DBEntryLength]);
primary_sets_[j].parity.XorFromRaw(&chunk[offset * entry_size_]);
}

// update the parities for the backup hints
for (uint64_t j = startIndexBackup; j < endIndexBackup; j++) {
const auto tmp =
PRFEvalWithLongKeyAndTag(long_key_, local_backup_sets_[j].tag, i);
const auto offset = tmp & (chunk_size_ - 1);
DBEntryXorFromRaw(&local_backup_sets_[j].parityAfterPunct,
&chunk[offset * DBEntryLength]);
local_backup_sets_[j].parityAfterPuncture.XorFromRaw(
&chunk[offset * entry_size_]);
}
});
}
Expand All @@ -171,10 +173,10 @@ void QueryServiceClient::FetchFullDB() {
// empty.
for (uint64_t j = 0; j < chunk_size_; j++) {
if (!hitMap[j]) {
std::array<uint64_t, DBEntryLength> entry_slice{};
std::memcpy(entry_slice.data(), &chunk[j * DBEntryLength],
DBEntryLength * sizeof(uint64_t));
const auto entry = DBEntryFromSlice(entry_slice);
std::vector<uint8_t> entry_slice(entry_size_);
std::memcpy(entry_slice.data(), &chunk[j * entry_size_],
entry_size_ * sizeof(uint8_t));
const auto entry = DBEntry::DBEntryFromSlice(entry_slice);
local_miss_elements_[j + (i * chunk_size_)] = entry;
}
}
Expand All @@ -185,30 +187,30 @@ void QueryServiceClient::FetchFullDB() {
const auto tag = local_backup_set_groups_[i].sets[k].get().tag;
const auto tmp = PRFEvalWithLongKeyAndTag(long_key_, tag, i);
const auto offset = tmp & (chunk_size_ - 1);
DBEntryXorFromRaw(
&local_backup_set_groups_[i].sets[k].get().parityAfterPunct,
&chunk[offset * DBEntryLength]);
local_backup_set_groups_[i].sets[k].get().parityAfterPuncture.XorFromRaw(
&chunk[offset * entry_size_]);
}

// store the replacement
std::mt19937_64 rng(yacl::crypto::FastRandU64());
yacl::crypto::Prg<uint64_t> prg(yacl::crypto::SecureRandU64());
for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里没看太明白,这里应该对应论文 Figure-1 中 Offline 预处理阶段的第三个阶段,即 Update backup table

论文的描述中,for i in {0, 1, 2, ...} \ {j},这里的 i 应该是遍历每一个 chunk 吧?当然对于当前正在处理的第 j 个 chunk 会跳过去,但是代码里好像没有看到这样的逻辑,而是把每一个 i 都无区别地用来做了这里的处理?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

之前是在多线程中处理所有 chunk,出来后再 XOR 第 i 个 chunk,以实现跳过的逻辑。但这样会有一些不必要的计算开销,已经修改为:在多线程内根据 backup hints 与 chunk 间的对应关系,实现跳过处理。

// generate a random offset between 0 and ChunkSize - 1
const auto offset = rng() & (chunk_size_ - 1);
const auto offset = prg() & (chunk_size_ - 1);
local_replacement_groups_[i].indices[k] = offset + i * chunk_size_;
std::array<uint64_t, DBEntryLength> entry_slice{};
std::memcpy(entry_slice.data(), &chunk[offset * DBEntryLength],
DBEntryLength * sizeof(uint64_t));
local_replacement_groups_[i].value[k] = DBEntryFromSlice(entry_slice);
std::vector<uint8_t> entry_slice(entry_size_);
std::memcpy(entry_slice.data(), &chunk[offset * entry_size_],
entry_size_ * sizeof(uint8_t));
local_replacement_groups_[i].value[k] =
DBEntry::DBEntryFromSlice(entry_slice);
}
}
}

void QueryServiceClient::SendDummySet() const {
std::mt19937_64 rng(yacl::crypto::FastRandU64());
yacl::crypto::Prg<uint64_t> prg(yacl::crypto::SecureRandU64());
std::vector<uint64_t> randSet(set_size_);
for (uint64_t i = 0; i < set_size_; i++) {
randSet[i] = rng() % chunk_size_ + i * chunk_size_;
randSet[i] = prg() % chunk_size_ + i * chunk_size_;
}

// send the random dummy set to the server
Expand Down Expand Up @@ -249,7 +251,7 @@ DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) {
}
}

DBEntry xVal = ZeroEntry();
DBEntry xVal = DBEntry::ZeroEntry(entry_size_);

if (hitSetId == std::numeric_limits<uint64_t>::max()) {
if (local_miss_elements_.find(x) == local_miss_elements_.end()) {
Expand Down Expand Up @@ -301,9 +303,9 @@ DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) {
const auto& parity = std::get<0>(parityQueryResponse);

// recover the answer
xVal = primary_sets_[hitSetId].parity; // the parity of the hit set
DBEntryXorFromRaw(&xVal, parity.data()); // xor the parity of the edited set
DBEntryXor(&xVal, &repVal); // xor the replacement value
xVal = primary_sets_[hitSetId].parity; // the parity of the hit set
xVal.XorFromRaw(parity.data()); // xor the parity of the edited set
xVal.Xor(repVal); // xor the replacement value

// update the local cache
local_cache_[x] = xVal;
Expand All @@ -319,9 +321,10 @@ DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) {
primary_sets_[hitSetId].tag =
local_backup_set_groups_[chunkId].sets[consumed].get().tag;
// backup set doesn't XOR the chunk(x)-th element in preparation
DBEntryXor(
&xVal,
&local_backup_set_groups_[chunkId].sets[consumed].get().parityAfterPunct);
xVal.Xor(local_backup_set_groups_[chunkId]
.sets[consumed]
.get()
.parityAfterPuncture);
primary_sets_[hitSetId].parity = xVal;
primary_sets_[hitSetId].isProgrammed = true;
// for load balancing, the chunk(x)-th element differs from the one expanded
Expand Down
75 changes: 35 additions & 40 deletions experimental/pir/piano/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,73 +2,64 @@

#include <spdlog/spdlog.h>

#include <cmath>
#include <cstdint>
#include <iostream>
#include <thread>
#include <utility>

#include "experimental/pir/piano/serialize.h"
#include "experimental/pir/piano/util.h"
#include "yacl/crypto/rand/rand.h"
#include "yacl/crypto/tools/prg.h"
#include "yacl/link/context.h"

namespace pir::piano {

class LocalSet {
public:
uint32_t tag; // the tag of the set
DBEntry parity;
uint64_t
programmedPoint; // identifier for the element replaced after refresh,
// differing from those expanded by PRFEval
bool isProgrammed;

LocalSet(const uint32_t tag, const DBEntry& parity,
const uint64_t programmed_point, const bool is_programmed)
struct LocalSet {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LocalSet, LocalBackupSet, LocalBackupSetGroup, LocalReplacementGroup

这几个类型主要拿来干嘛的?请添加有效注释说明一下。

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已添加注释说明

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,修改后请提交最新的commit

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已提交commit,请问我是否需要将代码从 psi/experimental/ 移动到 psi/experiment/ 目录下,我看已有实验性代码都在该目录下

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要的,请移动到psi/experiment/pir/ 目录下。这一点非常抱歉,我在之前的comments中,写成了experimental

LocalSet(const uint32_t tag, DBEntry parity, const uint64_t programmed_point,
const bool is_programmed)
: tag(tag),
parity(parity),
parity(std::move(parity)),
programmedPoint(programmed_point),
isProgrammed(is_programmed) {}

uint32_t tag; // the tag of the set
DBEntry parity;
uint64_t programmedPoint; // identifier for the element replaced after
// refresh differing from those expanded by PRFEval
bool isProgrammed;
};

class LocalBackupSet {
public:
uint32_t tag;
DBEntry parityAfterPunct;
struct LocalBackupSet {
LocalBackupSet(const uint32_t tag, DBEntry parity_after_puncture)
: tag(tag), parityAfterPuncture(std::move(parity_after_puncture)) {}

LocalBackupSet(const uint32_t tag, const DBEntry& parity_after_punct)
: tag(tag), parityAfterPunct(parity_after_punct) {}
uint32_t tag;
DBEntry parityAfterPuncture;
};

class LocalBackupSetGroup {
public:
uint64_t consumed;
std::vector<std::reference_wrapper<LocalBackupSet>> sets;

struct LocalBackupSetGroup {
LocalBackupSetGroup(
const uint64_t consumed,
const std::vector<std::reference_wrapper<LocalBackupSet>>& sets)
: consumed(consumed), sets(sets) {}
};

class LocalReplacementGroup {
public:
uint64_t consumed;
std::vector<uint64_t> indices;
std::vector<DBEntry> value;
std::vector<std::reference_wrapper<LocalBackupSet>> sets;
};

// Per-chunk group of pre-sampled replacement entries.
//
// In the visible client code, QueryServiceClient::FetchFullDB fills slot k of
// the group belonging to chunk i with:
//   indices[k] = offset + i * chunk_size, where offset is drawn at random
//                from within the chunk;
//   value[k]   = the database entry read from the chunk at that offset.
struct LocalReplacementGroup {
  // The vectors are sink parameters: they are taken by value and moved into
  // the members, so a caller holding temporaries (or using std::move) avoids
  // a deep copy, while callers passing lvalues keep working unchanged.
  LocalReplacementGroup(const uint64_t consumed, std::vector<uint64_t> indices,
                        std::vector<DBEntry> value)
      : consumed(consumed),
        indices(std::move(indices)),
        value(std::move(value)) {}

  // NOTE(review): presumably the number of replacement slots already used
  // from this group; the group is constructed with 0 -- confirm against the
  // online query path.
  uint64_t consumed;
  std::vector<uint64_t> indices;  // global database index of each replacement
  std::vector<DBEntry> value;     // entry stored at the matching index
};

class QueryServiceClient {
public:
static constexpr uint64_t FailureProbLog2 = 40;
uint64_t totalQueryNum{};

QueryServiceClient(uint64_t db_size, uint64_t thread_num,
QueryServiceClient(uint64_t entry_num, uint64_t thread_num,
uint64_t entry_size,
std::shared_ptr<yacl::link::Context> context);

void Initialize();
Expand All @@ -78,24 +69,28 @@ class QueryServiceClient {
DBEntry OnlineSingleQuery(uint64_t x);
std::vector<DBEntry> OnlineMultipleQueries(
const std::vector<uint64_t>& queries);
uint64_t getTotalQueryNumber() const { return total_query_num_; };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

方法名首字母统一用大写,请自己检查其他地方类似的问题,然后修改掉。

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已检查并修改


private:
uint64_t db_size_;
static constexpr uint64_t kFailureProbLog2 = 40;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个const没必要定义在这里吧?一般放在一个文件最上方的位置

uint64_t total_query_num_{};
uint64_t entry_num_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

初始化

uint64_t thread_num_;
std::shared_ptr<yacl::link::Context> context_;

uint64_t chunk_size_{};
uint64_t set_size_{};
uint64_t entry_size_{};
uint64_t primary_set_num_{};
uint64_t backup_set_num_per_chunk_{};
uint64_t total_backup_set_num_{};
PrfKey master_key_{};
uint128_t master_key_{};
yacl::crypto::AES_KEY long_key_{};

std::vector<LocalSet> primary_sets_;
std::vector<LocalBackupSet> local_backup_sets_;
std::map<uint64_t, DBEntry> local_cache_;
std::map<uint64_t, DBEntry> local_miss_elements_;
std::unordered_map<uint64_t, DBEntry> local_cache_;
std::unordered_map<uint64_t, DBEntry> local_miss_elements_;
std::vector<LocalBackupSetGroup> local_backup_set_groups_;
std::vector<LocalReplacementGroup> local_replacement_groups_;
};
Expand Down
4 changes: 2 additions & 2 deletions experimental/pir/piano/piano.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ message FetchFullDbMsg {
message DbChunk {
uint64 chunk_id = 1;
uint64 chunk_size = 2;
repeated uint64 chunks = 3;
bytes chunks = 3;
}

message SetParityQueryMsg {
Expand All @@ -18,7 +18,7 @@ message SetParityQueryMsg {
}

message SetParityQueryResponse {
repeated uint64 parity = 1;
bytes parity = 1;
uint64 server_compute_time = 2;
}

Expand Down
Loading
Loading