-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Piano: Extremely Simple, Single-server PIR with Sublinear Server Computation #196
base: main
Are you sure you want to change the base?
Changes from 1 commit
500d454
22bab60
057d545
cd4cc90
0f90718
b25e91a
14b90b4
4f61534
5397010
af8e4a1
9f900cb
2b114c7
d04c226
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,36 +20,37 @@ double FailProbBallIntoBins(const uint64_t ball_num, const uint64_t bin_num, | |
} | ||
|
||
QueryServiceClient::QueryServiceClient( | ||
const uint64_t db_size, const uint64_t thread_num, | ||
std::shared_ptr<yacl::link::Context> context) | ||
: db_size_(db_size), thread_num_(thread_num), context_(std::move(context)) { | ||
const uint64_t entry_num, const uint64_t thread_num, | ||
const uint64_t entry_size, std::shared_ptr<yacl::link::Context> context) | ||
: entry_num_(entry_num), | ||
thread_num_(thread_num), | ||
context_(std::move(context)), | ||
entry_size_(entry_size) { | ||
Initialize(); | ||
InitializeLocalSets(); | ||
} | ||
|
||
void QueryServiceClient::Initialize() { | ||
std::mt19937_64 rng(yacl::crypto::FastRandU64()); | ||
|
||
master_key_ = RandKey(rng); | ||
long_key_ = GetLongKey(&master_key_); | ||
master_key_ = SecureRandKey(); | ||
long_key_ = GetLongKey(master_key_); | ||
|
||
// Q = sqrt(n) * ln(n) | ||
totalQueryNum = | ||
static_cast<uint64_t>(std::sqrt(static_cast<double>(db_size_)) * | ||
std::log(static_cast<double>(db_size_))); | ||
total_query_num_ = | ||
static_cast<uint64_t>(std::sqrt(static_cast<double>(entry_num_)) * | ||
std::log(static_cast<double>(entry_num_))); | ||
|
||
std::tie(chunk_size_, set_size_) = GenParams(db_size_); | ||
std::tie(chunk_size_, set_size_) = GenParams(entry_num_); | ||
|
||
primary_set_num_ = | ||
primaryNumParam(static_cast<double>(totalQueryNum), | ||
static_cast<double>(chunk_size_), FailureProbLog2 + 1); | ||
primaryNumParam(static_cast<double>(total_query_num_), | ||
static_cast<double>(chunk_size_), kFailureProbLog2 + 1); | ||
// if localSetNum is not a multiple of thread_num_ then we need to add some | ||
// padding | ||
primary_set_num_ = | ||
(primary_set_num_ + thread_num_ - 1) / thread_num_ * thread_num_; | ||
|
||
backup_set_num_per_chunk_ = | ||
3 * static_cast<uint64_t>(static_cast<double>(totalQueryNum) / | ||
3 * static_cast<uint64_t>(static_cast<double>(total_query_num_) / | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的转换看着没必要 |
||
static_cast<double>(set_size_)); | ||
backup_set_num_per_chunk_ = | ||
(backup_set_num_per_chunk_ + thread_num_ - 1) / thread_num_ * thread_num_; | ||
|
@@ -67,8 +68,17 @@ void QueryServiceClient::InitializeLocalSets() { | |
local_miss_elements_.clear(); | ||
uint32_t tagCounter = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 局部变量的命令风格统一用下划线来连接小写字符 ,代码里还有很多类似的命名风格不统一的地方,请自己检查并修改。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已检查并修改 |
||
|
||
// Initialize primary_sets_ | ||
for (uint64_t j = 0; j < primary_set_num_; j++) { | ||
primary_sets_.emplace_back(tagCounter, ZeroEntry(), 0, false); | ||
primary_sets_.emplace_back(tagCounter, DBEntry::ZeroEntry(entry_size_), 0, | ||
false); | ||
tagCounter += 1; | ||
} | ||
|
||
// Initialize local_backup_sets_ | ||
for (uint64_t i = 0; i < total_backup_set_num_; ++i) { | ||
local_backup_sets_.emplace_back(tagCounter, | ||
DBEntry::ZeroEntry(entry_size_)); | ||
tagCounter += 1; | ||
} | ||
|
||
|
@@ -77,6 +87,7 @@ void QueryServiceClient::InitializeLocalSets() { | |
local_replacement_groups_.clear(); | ||
local_replacement_groups_.reserve(set_size_); | ||
|
||
// Initialize local_backup_set_groups_ and local_replacement_groups_ | ||
for (uint64_t i = 0; i < set_size_; i++) { | ||
std::vector<std::reference_wrapper<LocalBackupSet>> backupSets; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 看着感觉使用LocalBackupSet的span就可以,没必要用vector<reference_wrapper> |
||
for (uint64_t j = 0; j < backup_set_num_per_chunk_; j++) { | ||
|
@@ -91,14 +102,6 @@ void QueryServiceClient::InitializeLocalSets() { | |
LocalReplacementGroup replacementGroup(0, indices, values); | ||
local_replacement_groups_.emplace_back(std::move(replacementGroup)); | ||
} | ||
|
||
for (uint64_t j = 0; j < set_size_; j++) { | ||
for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { | ||
local_backup_set_groups_[j].sets[k].get() = | ||
LocalBackupSet{tagCounter, ZeroEntry()}; | ||
tagCounter += 1; | ||
} | ||
} | ||
} | ||
|
||
void QueryServiceClient::FetchFullDB() { | ||
|
@@ -145,17 +148,16 @@ void QueryServiceClient::FetchFullDB() { | |
std::lock_guard<std::mutex> lock(hitMapMutex); | ||
hitMap[offset] = true; | ||
} | ||
DBEntryXorFromRaw(&primary_sets_[j].parity, | ||
&chunk[offset * DBEntryLength]); | ||
primary_sets_[j].parity.XorFromRaw(&chunk[offset * entry_size_]); | ||
} | ||
|
||
// update the parities for the backup hints | ||
for (uint64_t j = startIndexBackup; j < endIndexBackup; j++) { | ||
const auto tmp = | ||
PRFEvalWithLongKeyAndTag(long_key_, local_backup_sets_[j].tag, i); | ||
const auto offset = tmp & (chunk_size_ - 1); | ||
DBEntryXorFromRaw(&local_backup_sets_[j].parityAfterPunct, | ||
&chunk[offset * DBEntryLength]); | ||
local_backup_sets_[j].parityAfterPuncture.XorFromRaw( | ||
&chunk[offset * entry_size_]); | ||
} | ||
}); | ||
} | ||
|
@@ -171,10 +173,10 @@ void QueryServiceClient::FetchFullDB() { | |
// empty. | ||
for (uint64_t j = 0; j < chunk_size_; j++) { | ||
if (!hitMap[j]) { | ||
std::array<uint64_t, DBEntryLength> entry_slice{}; | ||
std::memcpy(entry_slice.data(), &chunk[j * DBEntryLength], | ||
DBEntryLength * sizeof(uint64_t)); | ||
const auto entry = DBEntryFromSlice(entry_slice); | ||
std::vector<uint8_t> entry_slice(entry_size_); | ||
std::memcpy(entry_slice.data(), &chunk[j * entry_size_], | ||
entry_size_ * sizeof(uint8_t)); | ||
const auto entry = DBEntry::DBEntryFromSlice(entry_slice); | ||
local_miss_elements_[j + (i * chunk_size_)] = entry; | ||
} | ||
} | ||
|
@@ -185,30 +187,30 @@ void QueryServiceClient::FetchFullDB() { | |
const auto tag = local_backup_set_groups_[i].sets[k].get().tag; | ||
const auto tmp = PRFEvalWithLongKeyAndTag(long_key_, tag, i); | ||
const auto offset = tmp & (chunk_size_ - 1); | ||
DBEntryXorFromRaw( | ||
&local_backup_set_groups_[i].sets[k].get().parityAfterPunct, | ||
&chunk[offset * DBEntryLength]); | ||
local_backup_set_groups_[i].sets[k].get().parityAfterPuncture.XorFromRaw( | ||
&chunk[offset * entry_size_]); | ||
} | ||
|
||
// store the replacement | ||
std::mt19937_64 rng(yacl::crypto::FastRandU64()); | ||
yacl::crypto::Prg<uint64_t> prg(yacl::crypto::SecureRandU64()); | ||
for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里没看太明白,这里应该对应论文Figrue-1中的,Offline 预处理阶段的第三个阶段的 论文的描述中, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 之前是在多线程中处理所有 |
||
// generate a random offset between 0 and ChunkSize - 1 | ||
const auto offset = rng() & (chunk_size_ - 1); | ||
const auto offset = prg() & (chunk_size_ - 1); | ||
local_replacement_groups_[i].indices[k] = offset + i * chunk_size_; | ||
std::array<uint64_t, DBEntryLength> entry_slice{}; | ||
std::memcpy(entry_slice.data(), &chunk[offset * DBEntryLength], | ||
DBEntryLength * sizeof(uint64_t)); | ||
local_replacement_groups_[i].value[k] = DBEntryFromSlice(entry_slice); | ||
std::vector<uint8_t> entry_slice(entry_size_); | ||
std::memcpy(entry_slice.data(), &chunk[offset * entry_size_], | ||
entry_size_ * sizeof(uint8_t)); | ||
local_replacement_groups_[i].value[k] = | ||
DBEntry::DBEntryFromSlice(entry_slice); | ||
} | ||
} | ||
} | ||
|
||
void QueryServiceClient::SendDummySet() const { | ||
std::mt19937_64 rng(yacl::crypto::FastRandU64()); | ||
yacl::crypto::Prg<uint64_t> prg(yacl::crypto::SecureRandU64()); | ||
std::vector<uint64_t> randSet(set_size_); | ||
for (uint64_t i = 0; i < set_size_; i++) { | ||
randSet[i] = rng() % chunk_size_ + i * chunk_size_; | ||
randSet[i] = prg() % chunk_size_ + i * chunk_size_; | ||
} | ||
|
||
// send the random dummy set to the server | ||
|
@@ -249,7 +251,7 @@ DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) { | |
} | ||
} | ||
|
||
DBEntry xVal = ZeroEntry(); | ||
DBEntry xVal = DBEntry::ZeroEntry(entry_size_); | ||
|
||
if (hitSetId == std::numeric_limits<uint64_t>::max()) { | ||
if (local_miss_elements_.find(x) == local_miss_elements_.end()) { | ||
|
@@ -301,9 +303,9 @@ DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) { | |
const auto& parity = std::get<0>(parityQueryResponse); | ||
|
||
// recover the answer | ||
xVal = primary_sets_[hitSetId].parity; // the parity of the hit set | ||
DBEntryXorFromRaw(&xVal, parity.data()); // xor the parity of the edited set | ||
DBEntryXor(&xVal, &repVal); // xor the replacement value | ||
xVal = primary_sets_[hitSetId].parity; // the parity of the hit set | ||
xVal.XorFromRaw(parity.data()); // xor the parity of the edited set | ||
xVal.Xor(repVal); // xor the replacement value | ||
|
||
// update the local cache | ||
local_cache_[x] = xVal; | ||
|
@@ -319,9 +321,10 @@ DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) { | |
primary_sets_[hitSetId].tag = | ||
local_backup_set_groups_[chunkId].sets[consumed].get().tag; | ||
// backup set doesn't XOR the chunk(x)-th element in preparation | ||
DBEntryXor( | ||
&xVal, | ||
&local_backup_set_groups_[chunkId].sets[consumed].get().parityAfterPunct); | ||
xVal.Xor(local_backup_set_groups_[chunkId] | ||
.sets[consumed] | ||
.get() | ||
.parityAfterPuncture); | ||
primary_sets_[hitSetId].parity = xVal; | ||
primary_sets_[hitSetId].isProgrammed = true; | ||
// for load balancing, the chunk(x)-th element differs from the one expanded | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,73 +2,64 @@ | |
|
||
#include <spdlog/spdlog.h> | ||
|
||
#include <cmath> | ||
#include <cstdint> | ||
#include <iostream> | ||
#include <thread> | ||
#include <utility> | ||
|
||
#include "experimental/pir/piano/serialize.h" | ||
#include "experimental/pir/piano/util.h" | ||
#include "yacl/crypto/rand/rand.h" | ||
#include "yacl/crypto/tools/prg.h" | ||
#include "yacl/link/context.h" | ||
|
||
namespace pir::piano { | ||
|
||
class LocalSet { | ||
public: | ||
uint32_t tag; // the tag of the set | ||
DBEntry parity; | ||
uint64_t | ||
programmedPoint; // identifier for the element replaced after refresh, | ||
// differing from those expanded by PRFEval | ||
bool isProgrammed; | ||
|
||
LocalSet(const uint32_t tag, const DBEntry& parity, | ||
const uint64_t programmed_point, const bool is_programmed) | ||
struct LocalSet { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LocalSet, LocalBackupSet, LocalBackupSetGroup, LocalReplacementGroup 这几个类型主要拿来干嘛的?请添加有效注释说明一下。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已添加注释说明 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,修改后请提交最新的commit There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已提交commit,请问我是否需要将代码从 psi/experimental/ 移动到 psi/experiment/ 目录下,我看已有实验性代码都在该目录下 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 需要的,请移动到 |
||
LocalSet(const uint32_t tag, DBEntry parity, const uint64_t programmed_point, | ||
const bool is_programmed) | ||
: tag(tag), | ||
parity(parity), | ||
parity(std::move(parity)), | ||
programmedPoint(programmed_point), | ||
isProgrammed(is_programmed) {} | ||
|
||
uint32_t tag; // the tag of the set | ||
DBEntry parity; | ||
uint64_t programmedPoint; // identifier for the element replaced after | ||
// refresh differing from those expanded by PRFEval | ||
bool isProgrammed; | ||
}; | ||
|
||
class LocalBackupSet { | ||
public: | ||
uint32_t tag; | ||
DBEntry parityAfterPunct; | ||
struct LocalBackupSet { | ||
LocalBackupSet(const uint32_t tag, DBEntry parity_after_puncture) | ||
: tag(tag), parityAfterPuncture(std::move(parity_after_puncture)) {} | ||
|
||
LocalBackupSet(const uint32_t tag, const DBEntry& parity_after_punct) | ||
: tag(tag), parityAfterPunct(parity_after_punct) {} | ||
uint32_t tag; | ||
DBEntry parityAfterPuncture; | ||
}; | ||
|
||
class LocalBackupSetGroup { | ||
public: | ||
uint64_t consumed; | ||
std::vector<std::reference_wrapper<LocalBackupSet>> sets; | ||
|
||
struct LocalBackupSetGroup { | ||
LocalBackupSetGroup( | ||
const uint64_t consumed, | ||
const std::vector<std::reference_wrapper<LocalBackupSet>>& sets) | ||
: consumed(consumed), sets(sets) {} | ||
}; | ||
|
||
class LocalReplacementGroup { | ||
public: | ||
uint64_t consumed; | ||
std::vector<uint64_t> indices; | ||
std::vector<DBEntry> value; | ||
std::vector<std::reference_wrapper<LocalBackupSet>> sets; | ||
}; | ||
|
||
struct LocalReplacementGroup { | ||
LocalReplacementGroup(const uint64_t consumed, | ||
const std::vector<uint64_t>& indices, | ||
const std::vector<DBEntry>& value) | ||
: consumed(consumed), indices(indices), value(value) {} | ||
|
||
uint64_t consumed; | ||
std::vector<uint64_t> indices; | ||
std::vector<DBEntry> value; | ||
}; | ||
|
||
class QueryServiceClient { | ||
public: | ||
static constexpr uint64_t FailureProbLog2 = 40; | ||
uint64_t totalQueryNum{}; | ||
|
||
QueryServiceClient(uint64_t db_size, uint64_t thread_num, | ||
QueryServiceClient(uint64_t entry_num, uint64_t thread_num, | ||
uint64_t entry_size, | ||
std::shared_ptr<yacl::link::Context> context); | ||
|
||
void Initialize(); | ||
|
@@ -78,24 +69,28 @@ class QueryServiceClient { | |
DBEntry OnlineSingleQuery(uint64_t x); | ||
std::vector<DBEntry> OnlineMultipleQueries( | ||
const std::vector<uint64_t>& queries); | ||
uint64_t getTotalQueryNumber() const { return total_query_num_; }; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 方法名首字母统一用大写,请自己检查其他地方类似的问题,然后修改掉。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已检查并修改 |
||
|
||
private: | ||
uint64_t db_size_; | ||
static constexpr uint64_t kFailureProbLog2 = 40; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个const没必要定义在这里吧?一般放在一个文件最上方的位置 |
||
uint64_t total_query_num_{}; | ||
uint64_t entry_num_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 初始化 |
||
uint64_t thread_num_; | ||
std::shared_ptr<yacl::link::Context> context_; | ||
|
||
uint64_t chunk_size_{}; | ||
uint64_t set_size_{}; | ||
uint64_t entry_size_{}; | ||
uint64_t primary_set_num_{}; | ||
uint64_t backup_set_num_per_chunk_{}; | ||
uint64_t total_backup_set_num_{}; | ||
PrfKey master_key_{}; | ||
uint128_t master_key_{}; | ||
yacl::crypto::AES_KEY long_key_{}; | ||
|
||
std::vector<LocalSet> primary_sets_; | ||
std::vector<LocalBackupSet> local_backup_sets_; | ||
std::map<uint64_t, DBEntry> local_cache_; | ||
std::map<uint64_t, DBEntry> local_miss_elements_; | ||
std::unordered_map<uint64_t, DBEntry> local_cache_; | ||
std::unordered_map<uint64_t, DBEntry> local_miss_elements_; | ||
std::vector<LocalBackupSetGroup> local_backup_set_groups_; | ||
std::vector<LocalReplacementGroup> local_replacement_groups_; | ||
}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thread_num是不是要做检验?如果是0咋办