Skip to content

Commit 84a1505

Browse files
PtxtArray serialisation
Co-authored-by: Hamish Hunt <hamishun@gmail.com>
1 parent 202b3e1 commit 84a1505

File tree

5 files changed

+595
-23
lines changed

5 files changed

+595
-23
lines changed

examples/tutorial/07_ckks_serialization.cpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,14 @@ int main(int argc, char* argv[])
7171
std::cout << "\n\n*** Public Key:\n";
7272
std::cout << public_key.writeToJSON().pretty() << std::endl;
7373

74-
// Create a Ptxt data object
74+
// Create a data object
7575
std::vector<long> data(context.getNSlots());
7676

7777
// Generate some data
7878
std::iota(data.begin(), data.end(), 0);
7979

80-
// Create a ptxt. Note that in this tutorial we make use of the
81-
// alternative ptxt API.
82-
helib::Ptxt<helib::CKKS> ptxt(context, data);
80+
// Create a PtxtArray.
81+
helib::PtxtArray ptxt(context, data);
8382

8483
// Print the ptxt to stdout
8584
std::cout << "\n\n*** Ptxt:\n";
@@ -89,7 +88,7 @@ int main(int argc, char* argv[])
8988
helib::Ctxt ctxt(public_key);
9089

9190
// Encrypt `data` into the ciphertext
92-
public_key.Encrypt(ctxt, ptxt);
91+
ptxt.encrypt(ctxt);
9392

9493
// Print the ctxt to stdout
9594
std::cout << "\n\n*** Ctxt:\n";

examples/tutorial/08_ckks_deserialization.cpp

+10-8
Original file line numberDiff line numberDiff line change
@@ -151,21 +151,23 @@ int main(int argc, char* argv[])
151151
// inspect the file.
152152
std::remove("pk.json");
153153

154-
// Create a Ptxt data object
154+
// Create a data object
155155
std::vector<long> data(context.getNSlots());
156156

157157
// Generate some data
158158
std::iota(data.begin(), data.end(), 0);
159159

160-
// Create a ptxt. Note that in this tutorial we make use of the
161-
// alternative ptxt API.
162-
helib::Ptxt<helib::CKKS> ptxt(context, data);
160+
// Create a ptxt.
161+
helib::PtxtArray ptxt(context, data);
163162

164163
std::ofstream outPtxtFile;
165164
outPtxtFile.open("ptxt.json", std::ios::out);
166165
if (outPtxtFile.is_open()) {
167166
// Write the ptxt to a file
168167
ptxt.writeToJSON(outPtxtFile);
168+
// Alternatively one can use
169+
// outPtxtFile << ptxt;
170+
169171
// Close the ofstream
170172
outPtxtFile.close();
171173
} else {
@@ -176,11 +178,11 @@ int main(int argc, char* argv[])
176178
inPtxtFile.open("ptxt.json");
177179
if (inPtxtFile.is_open()) {
178180
// Read in the ptxt from the file
179-
helib::Ptxt<helib::CKKS> deserializedPtxt =
180-
helib::Ptxt<helib::CKKS>::readFromJSON(inPtxtFile, context);
181+
helib::PtxtArray deserializedPtxt =
182+
helib::PtxtArray::readFromJSON(inPtxtFile, context);
181183
// Note there are alternative methods for deserialization of Ptxt objects.
182184
// After initialization
183-
// helib::Ptxt<helib::CKKS> deserializedPtxt(publicKey);
185+
// helib::PtxtArray deserializedPtxt(publicKey);
184186
// One can write
185187
// inPtxtFile >> deserializedPtxt;
186188
// Or alternatively
@@ -199,7 +201,7 @@ int main(int argc, char* argv[])
199201
helib::Ctxt ctxt(publicKey);
200202

201203
// Encrypt `data` into the ciphertext
202-
publicKey.Encrypt(ctxt, ptxt);
204+
ptxt.encrypt(ctxt);
203205

204206
std::ofstream outCtxtFile;
205207
outCtxtFile.open("ctxt.json", std::ios::out);

include/helib/EncryptedArray.h

+26-5
Original file line numberDiff line numberDiff line change
@@ -2144,6 +2144,12 @@ void runningSums(const EncryptedArray& ea, PlaintextArray& pa);
21442144
class PtxtArray
21452145
{
21462146
public:
2147+
/**
2148+
* @brief Class label to be added to JSON serialization as object type
2149+
* information.
2150+
*/
2151+
static constexpr std::string_view typeName = "PtxtArray";
2152+
21472153
// These two data fields should really be private, but there are
21482154
// a lot of internal functions that need to access them
21492155
const EncryptedArray& ea;
@@ -2189,6 +2195,8 @@ class PtxtArray
21892195
const EncryptedArray& getView() const { return ea; }
21902196
const EncryptedArray& getEA() const { return ea; }
21912197

2198+
long size() const { return ea.size(); }
2199+
21922200
// direct encode, encrypt, and decrypt methods
21932201
void encode(EncodedPtxt& eptxt,
21942202
double mag = -1,
@@ -2256,6 +2264,7 @@ class PtxtArray
22562264
void random() { helib::random(ea, pa); }
22572265

22582266
//======== load ========
2267+
// Puts vector or scalar data into a PtxtArray
22592268

22602269
void load(const std::vector<int>& array)
22612270
{
@@ -2349,6 +2358,7 @@ class PtxtArray
23492358
}
23502359

23512360
//============== store ============
2361+
// Puts data into a std::vector aka `unload`
23522362

23532363
void store(std::vector<long>& array) const { decode(ea, array, pa); }
23542364

@@ -2362,12 +2372,23 @@ class PtxtArray
23622372

23632373
// this is here for consistency with Ctxt class
23642374
void negate() { helib::negate(ea, pa); }
2365-
};
23662375

2367-
inline std::ostream& operator<<(std::ostream& s, const PtxtArray& a)
2368-
{
2369-
return s << a.pa;
2370-
}
2376+
void writeToJSON(std::ostream& os) const;
2377+
2378+
JsonWrapper writeToJSON() const;
2379+
2380+
static PtxtArray readFromJSON(std::istream& is, const Context& context);
2381+
2382+
static PtxtArray readFromJSON(const JsonWrapper& jw, const Context& context);
2383+
2384+
void readJSON(std::istream& is);
2385+
2386+
void readJSON(const JsonWrapper& jw);
2387+
2388+
friend std::istream& operator>>(std::istream& is, PtxtArray& pa);
2389+
2390+
friend std::ostream& operator<<(std::ostream& os, const PtxtArray& pa);
2391+
};
23712392

23722393
inline void rotate(PtxtArray& a, long k) { rotate(a.ea, a.pa, k); }
23732394

src/EncryptedArray.cpp

+143
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include <helib/norms.h>
2020
#include <helib/exceptions.h>
2121

22+
#include "io.h"
23+
2224
namespace helib {
2325

2426
EncryptedArrayBase* buildEncryptedArray(const Context& context,
@@ -547,6 +549,147 @@ void EncryptedArrayDerived<type>::initNormalBasisMatrix() const
547549
} while (0);
548550
}
549551

552+
// PtxtArray member functions
553+
554+
void PtxtArray::writeToJSON(std::ostream& os) const
555+
{
556+
executeRedirectJsonError<void>([&]() { os << this->writeToJSON(); });
557+
}
558+
559+
JsonWrapper PtxtArray::writeToJSON() const
560+
{
561+
auto body = [this]() {
562+
json jslots;
563+
564+
if (ea.isCKKS()) {
565+
// When it is CKKS
566+
std::vector<std::complex<double>> data;
567+
store(data);
568+
jslots = data;
569+
} else {
570+
// When is it BGV
571+
std::vector<NTL::ZZX> data;
572+
store(data);
573+
std::vector<std::vector<long>> slots(data.size());
574+
for (std::size_t i = 0; i < data.size(); ++i) {
575+
long deg = NTL::deg(data[i]);
576+
if (deg == -1) {
577+
slots[i].emplace_back(0);
578+
}
579+
for (long j = 0; j <= deg; ++j) {
580+
slots[i].emplace_back(NTL::conv<long>(data[i][j]));
581+
}
582+
}
583+
jslots = slots;
584+
}
585+
586+
json j{{"scheme", (ea.isCKKS() ? "CKKS" : "BGV")}, {"slots", jslots}};
587+
588+
return wrap(toTypedJson<PtxtArray>(j));
589+
};
590+
591+
return executeRedirectJsonError<JsonWrapper>(body);
592+
}
593+
594+
PtxtArray PtxtArray::readFromJSON(std::istream& is, const Context& context)
595+
{
596+
PtxtArray ret{context};
597+
ret.readJSON(is);
598+
return ret;
599+
}
600+
601+
PtxtArray PtxtArray::readFromJSON(const JsonWrapper& tjw,
602+
const Context& context)
603+
{
604+
PtxtArray ret{context};
605+
ret.readJSON(tjw);
606+
return ret;
607+
}
608+
609+
void PtxtArray::readJSON(std::istream& is)
610+
{
611+
executeRedirectJsonError<void>([&]() {
612+
json j;
613+
is >> j;
614+
this->readJSON(wrap(j));
615+
});
616+
}
617+
618+
void PtxtArray::readJSON(const JsonWrapper& tjw)
619+
{
620+
auto body = [&]() {
621+
json tj = unwrap(tjw);
622+
json jslots;
623+
// if the input is just an array short-circuit to slot deserialization
624+
// (assuming there is no type-header).
625+
if (tj.is_array()) {
626+
jslots = tj;
627+
628+
} else {
629+
json j = fromTypedJson<PtxtArray>(tj);
630+
631+
std::string expected_scheme{j.at("scheme").get<std::string>()};
632+
assertTrue<IOError>(
633+
expected_scheme == (ea.isCKKS() ? "CKKS" : "BGV"),
634+
"Scheme mismatch in deserialization.\nExpected: " + expected_scheme +
635+
", actual: " + std::string(ea.isCKKS() ? "CKKS" : "BGV") + ".");
636+
637+
jslots = j.at("slots");
638+
639+
if (!jslots.is_array()) {
640+
throw IOError("Slot content is not a JSON array");
641+
}
642+
}
643+
644+
if (static_cast<long>(jslots.size()) > this->getEA().size()) {
645+
std::stringstream err_msg;
646+
err_msg << "Cannot deserialize to PtxtArray: not enough slots. "
647+
<< "Trying to deserialize " << jslots.size() << " elements. "
648+
<< "Got " << this->getEA().size() << " slots.";
649+
throw IOError(err_msg.str());
650+
}
651+
652+
if (ea.isCKKS()) {
653+
// Scheme is CKKS
654+
this->load(jslots.get<std::vector<std::complex<double>>>());
655+
} else {
656+
// Scheme is BGV
657+
auto json2data = [](const json& jslots) {
658+
std::vector<NTL::ZZX> data;
659+
data.reserve(jslots.size());
660+
for (const auto& jslot : jslots) {
661+
NTL::ZZX slot;
662+
if (jslot.is_array()) {
663+
for (std::size_t i = 0; i < jslot.size(); ++i) {
664+
NTL::SetCoeff(slot, i, static_cast<long>(jslot[i]));
665+
}
666+
} else {
667+
// Slot is a single number
668+
slot = static_cast<long>(jslot);
669+
}
670+
data.emplace_back(slot);
671+
}
672+
return data;
673+
};
674+
this->load(json2data(jslots));
675+
}
676+
};
677+
678+
executeRedirectJsonError<void>(body);
679+
}
680+
681+
std::istream& operator>>(std::istream& is, PtxtArray& pa)
682+
{
683+
pa.readJSON(is);
684+
return is;
685+
}
686+
687+
std::ostream& operator<<(std::ostream& os, const PtxtArray& pa)
688+
{
689+
pa.writeToJSON(os);
690+
return os;
691+
}
692+
550693
// Other functions...
551694

552695
void runningSums(const EncryptedArray& ea, Ctxt& ctxt)

0 commit comments

Comments
 (0)