LCOV - code coverage report
Current view: top level - src/bls - bls_worker.cpp (source / functions) Hit Total Coverage
Test: total_coverage.info Lines: 312 532 58.6 %
Date: 2026-06-25 07:23:43 Functions: 128 254 50.4 %

          Line data    Source code
       1             : // Copyright (c) 2018-2025 The Dash Core developers
       2             : // Distributed under the MIT/X11 software license, see the accompanying
       3             : // file COPYING or http://www.opensource.org/licenses/mit-license.php.
       4             : 
       5             : #include <bls/bls_worker.h>
       6             : #include <hash.h>
       7             : 
       8             : #include <util/system.h>
       9             : 
      10             : #include <memory>
      11             : #include <ranges>
      12             : #include <utility>
      13             : 
      14             : template <typename T>
      15       12938 : bool VerifyVectorHelper(Span<T> vec)
      16             : {
      17       12938 :     std::set<uint256> set;
      18       52979 :     for (const auto& item : vec) {
      19       40043 :         if (!item.IsValid())
      20           0 :             return false;
      21             :         // check duplicates
      22       40040 :         if (!set.emplace(item.GetHash()).second) {
      23           0 :             return false;
      24             :         }
      25             :     }
      26       12936 :     return true;
      27       12942 : }
      28             : 
      29             : // Creates a doneCallback and a future. The doneCallback simply finishes the future
      30             : template <typename T>
      31       10326 : std::pair<std::function<void(const T&)>, std::future<T> > BuildFutureDoneCallback()
      32             : {
      33       10326 :     auto p = std::make_shared<std::promise<T> >();
      34       20652 :     std::function<void(const T&)> f = [p](const T& v) {
      35       10326 :         p->set_value(v);
      36       10326 :     };
      37       10326 :     return std::make_pair(std::move(f), p->get_future());
      38       10326 : }
      39             : template <typename T>
      40           0 : std::pair<std::function<void(T)>, std::future<T> > BuildFutureDoneCallback2()
      41             : {
      42           0 :     auto p = std::make_shared<std::promise<T> >();
      43           0 :     std::function<void(const T&)> f = [p](T v) {
      44           0 :         p->set_value(v);
      45           0 :     };
      46           0 :     return std::make_pair(std::move(f), p->get_future());
      47           0 : }
      48             : 
      49             : 
      50             : /////
      51             : 
      52        6126 : CBLSWorker::CBLSWorker() = default;
      53             : 
      54        6126 : CBLSWorker::~CBLSWorker()
      55        3063 : {
      56        3063 :     Stop();
      57        6126 : }
      58             : 
      59        3063 : void CBLSWorker::Start(int16_t worker_count)
      60             : {
      61        3063 :     assert(worker_count > 0);
      62        3063 :     workerPool.resize(worker_count);
      63        3063 :     RenameThreadPool(workerPool, "bls-work");
      64        3063 : }
      65             : 
      66        6126 : void CBLSWorker::Stop()
      67             : {
      68        6126 :     workerPool.clear_queue();
      69        6126 :     workerPool.stop(true);
      70        6126 : }
      71             : 
      72        2604 : bool CBLSWorker::GenerateContributions(int quorumThreshold, Span<CBLSId> ids, BLSVerificationVectorPtr& vvecRet, std::vector<CBLSSecretKey>& skSharesRet)
      73             : {
      74        2604 :     auto svec = std::vector<CBLSSecretKey>((size_t)quorumThreshold);
      75        2604 :     vvecRet = std::make_shared<std::vector<CBLSPublicKey>>((size_t)quorumThreshold);
      76        2604 :     skSharesRet.resize(ids.size());
      77             : 
      78        9733 :     for (int i = 0; i < quorumThreshold; i++) {
      79        7129 :         svec[i].MakeNewKey();
      80        7129 :     }
      81        2604 :     size_t batchSize = 8;
      82        2604 :     std::vector<std::future<bool>> futures;
      83        2604 :     futures.reserve((quorumThreshold / batchSize + ids.size() / batchSize) + 2);
      84             : 
      85        5208 :     for (size_t i = 0; i < size_t(quorumThreshold); i += batchSize) {
      86        2604 :         size_t start = i;
      87        2604 :         size_t count = std::min(batchSize, quorumThreshold - start);
      88        5207 :         auto f = [&, start, count](int threadId) {
      89        9732 :             for (size_t j = start; j < start + count; j++) {
      90        7129 :                 (*vvecRet)[j] = svec[j].GetPublicKey();
      91        7129 :             }
      92        2603 :             return true;
      93             :         };
      94        2604 :         futures.emplace_back(workerPool.push(f));
      95        2604 :     }
      96             : 
      97        5208 :     for (size_t i = 0; i < ids.size(); i += batchSize) {
      98        2604 :         size_t start = i;
      99        2604 :         size_t count = std::min(batchSize, ids.size() - start);
     100        5208 :         auto f = [&, start, count](int threadId) {
     101       12390 :             for (size_t j = start; j < start + count; j++) {
     102        9786 :                 if (!skSharesRet[j].SecretKeyShare(svec, ids[j])) {
     103           0 :                     return false;
     104             :                 }
     105        9786 :             }
     106        2604 :             return true;
     107        2604 :         };
     108        2604 :         futures.emplace_back(workerPool.push(f));
     109        2604 :     }
     110        7812 :     return std::ranges::all_of(futures, [](auto& f) { return f.get(); });
     111        2604 : }
     112             : 
     113             : // aggregates a single vector of BLS objects in parallel
     114             : // the input vector is split into batches and each batch is aggregated in parallel
     115             : // when enough batches are finished to form a new batch, the new batch is queued for further parallel aggregation
     116             : // when no more batches can be created from finished batch results, the final aggregated is created and the doneCallback
     117             : // called.
     118             : // The Aggregator object needs to be created on the heap, and it will delete itself after calling the doneCallback
     119             : // The input vector is not copied into the Aggregator but instead a vector of pointers to the original entries from the
     120             : // input vector is stored. This means that the input vector must stay alive for the whole lifetime of the Aggregator
     121             : template <typename T>
     122             : struct Aggregator : public std::enable_shared_from_this<Aggregator<T>> {
     123       23768 :     const size_t BATCH_SIZE{16};
     124             :     std::shared_ptr<std::vector<const T*> > inputVec;
     125             : 
     126             :     bool parallel;
     127             :     ctpl::thread_pool& workerPool;
     128             : 
     129             :     std::mutex m;
     130             :     // items in the queue are all intermediate aggregation results of finished batches.
     131             :     // The intermediate results must be deleted by us again (which we do in SyncAggregateAndPushAggQueue)
     132             :     ctpl::detail::Queue<T*> aggQueue;
     133       23768 :     std::atomic<size_t> aggQueueSize{0};
     134             : 
     135             :     // keeps track of currently queued/in-progress batches. If it reaches 0, we are done
     136       23768 :     std::atomic<size_t> waitCount{0};
     137             : 
     138             :     using DoneCallback = std::function<void(const T& agg)>;
     139             :     DoneCallback doneCallback;
     140             : 
     141             :     // TP can either be a pointer or a reference
     142             :     template <typename TP>
     143       47536 :     Aggregator(Span<TP> _inputSpan, bool _parallel,
     144             :                ctpl::thread_pool& _workerPool,
     145             :                DoneCallback _doneCallback) :
     146       23768 :             inputVec(std::make_shared<std::vector<const T*>>(_inputSpan.size())),
     147       23768 :             parallel(_parallel),
     148       23768 :             workerPool(_workerPool),
     149       23768 :             doneCallback(std::move(_doneCallback))
     150       23768 :     {
     151      112636 :         for (size_t i = 0; i < _inputSpan.size(); i++) {
     152       88868 :             (*inputVec)[i] = pointer(_inputSpan[i]);
     153       88868 :         }
     154       47536 :     }
     155             : 
     156       23338 :     const T* pointer(const T& v) { return &v; }
     157       65524 :     const T* pointer(const T* v) { return v; }
     158             : 
     159             :     // Starts aggregation.
     160             :     // If parallel=true, then this will return fast, otherwise this will block until aggregation is done
     161       23770 :     void Start()
     162             :     {
     163       23770 :         size_t batchCount = (inputVec->size() + BATCH_SIZE - 1) / BATCH_SIZE;
     164             : 
     165       23770 :         if (!parallel) {
     166           0 :             if (inputVec->size() == 1) {
     167           0 :                 doneCallback(*(*inputVec)[0]);
     168           0 :             } else {
     169           0 :                 doneCallback(SyncAggregate(Span{*inputVec}, 0, inputVec->size()));
     170             :             }
     171           0 :             return;
     172             :         }
     173             : 
     174       23770 :         if (batchCount == 1) {
     175             :             // just a single batch of work, take a shortcut.
     176       23770 :             auto self(this->shared_from_this());
     177       47474 :             PushWork([this, self](int threadId) {
     178       23704 :                 if (inputVec->size() == 1) {
     179         413 :                     doneCallback(*(*inputVec)[0]);
     180         413 :                 } else {
     181       23291 :                     doneCallback(SyncAggregate(Span{*inputVec}, 0, inputVec->size()));
     182             :                 }
     183       23704 :             });
     184             :             return;
     185       23770 :         }
     186             : 
     187             :         // increment wait counter as otherwise the first finished async aggregation might signal that we're done
     188           0 :         IncWait();
     189           0 :         for (size_t i = 0; i < batchCount; i++) {
     190           0 :             size_t start = i * BATCH_SIZE;
     191           0 :             size_t count = std::min(BATCH_SIZE, inputVec->size() - start);
     192           0 :             AsyncAggregateAndPushAggQueue(inputVec, start, count, false);
     193           0 :         }
     194             :         // this will decrement the wait counter and in most cases NOT finish, as async work is still in progress
     195           0 :         CheckDone();
     196       23770 :     }
     197             : 
     198           0 :     void IncWait()
     199             :     {
     200           0 :         ++waitCount;
     201           0 :     }
     202             : 
     203           0 :     void CheckDone()
     204             :     {
     205           0 :         if (--waitCount == 0) {
     206           0 :             Finish();
     207           0 :         }
     208           0 :     }
     209             : 
     210           0 :     void Finish()
     211             :     {
     212             :         // All async work is done, but we might have items in the aggQueue which are the results of the async
     213             :         // work. This is the case when these did not add up to a new batch. In this case, we have to aggregate
     214             :         // the items into the final result
     215             : 
     216           0 :         std::vector<T*> rem(aggQueueSize);
     217           0 :         for (size_t i = 0; i < rem.size(); i++) {
     218           0 :             T* p = nullptr;
     219           0 :             bool s = aggQueue.pop(p);
     220           0 :             assert(s);
     221           0 :             rem[i] = p;
     222           0 :         }
     223             : 
     224           0 :         T r;
     225           0 :         if (rem.size() == 1) {
     226             :             // just one intermediate result, which is actually the final result
     227           0 :             r = *rem[0];
     228           0 :         } else {
     229             :             // multiple intermediate results left which did not add up to a new batch. aggregate them now
     230           0 :             r = SyncAggregate(Span{rem}, 0, rem.size());
     231             :         }
     232             : 
     233             :         // all items which are left in the queue are intermediate results, so we must delete them
     234           0 :         for (size_t i = 0; i < rem.size(); i++) {
     235           0 :             delete rem[i];
     236           0 :         }
     237           0 :         doneCallback(r);
     238           0 :     }
     239             : 
     240           0 :     void AsyncAggregateAndPushAggQueue(const std::shared_ptr<std::vector<const T*>>& vec, size_t start, size_t count, bool del)
     241             :     {
     242           0 :         IncWait();
     243           0 :         auto self(this->shared_from_this());
     244           0 :         PushWork([self, vec, start, count, del](int threadId){
     245           0 :             self->SyncAggregateAndPushAggQueue(vec, start, count, del);
     246           0 :         });
     247           0 :     }
     248             : 
     249           0 :     void SyncAggregateAndPushAggQueue(const std::shared_ptr<std::vector<const T*>>& vec, size_t start, size_t count, bool del)
     250             :     {
     251             :         // aggregate vec and push the intermediate result onto the work queue
     252           0 :         PushAggQueue(SyncAggregate(Span{*vec}, start, count));
     253           0 :         if (del) {
     254           0 :             for (size_t i = 0; i < count; i++) {
     255           0 :                 delete (*vec)[start + i];
     256           0 :             }
     257           0 :         }
     258           0 :         CheckDone();
     259           0 :     }
     260             : 
     261           0 :     void PushAggQueue(const T& v)
     262             :     {
     263           0 :         auto copyT = new T(v);
     264             :         try {
     265           0 :             aggQueue.push(copyT);
     266           0 :         } catch (...) {
     267           0 :             delete copyT;
     268           0 :             throw;
     269           0 :         }
     270             : 
     271           0 :         if (++aggQueueSize >= BATCH_SIZE) {
     272             :             // we've collected enough intermediate results to form a new batch.
     273           0 :             std::shared_ptr<std::vector<const T*> > newBatch;
     274             :             {
     275           0 :                 std::unique_lock<std::mutex> l(m);
     276           0 :                 if (aggQueueSize < BATCH_SIZE) {
     277             :                     // some other worker thread grabbed this batch
     278           0 :                     return;
     279             :                 }
     280           0 :                 newBatch = std::make_shared<std::vector<const T*>>(BATCH_SIZE);
     281             :                 // collect items for new batch
     282           0 :                 for (size_t i = 0; i < BATCH_SIZE; i++) {
     283           0 :                     T* p = nullptr;
     284           0 :                     bool s = aggQueue.pop(p);
     285           0 :                     assert(s);
     286           0 :                     (*newBatch)[i] = p;
     287           0 :                 }
     288           0 :                 aggQueueSize -= BATCH_SIZE;
     289           0 :             }
     290             : 
     291             :             // push new batch to work queue. del=true this time as these items are intermediate results and need to be deleted
     292             :             // after aggregation is done
     293           0 :             AsyncAggregateAndPushAggQueue(newBatch, 0, newBatch->size(), true);
     294           0 :         }
     295           0 :     }
     296             : 
     297             :     template <typename TP>
     298       23253 :     T SyncAggregate(Span<TP> vec, size_t start, size_t count)
     299             :     {
     300       23253 :         T result = *vec[start];
     301       87951 :         for (size_t j = 1; j < count; j++) {
     302       64698 :             result.AggregateInsecure(*vec[start + j]);
     303       64698 :         }
     304       23253 :         return result;
     305        6267 :     }
     306             : 
     307             :     template <typename Callable>
     308       23768 :     void PushWork(Callable&& f)
     309             :     {
     310       23768 :         workerPool.push(f);
     311       23768 :     }
     312             : };
     313             : 
     314             : // Aggregates multiple input vectors into a single output vector
     315             : // Inputs are in the following form:
     316             : //   [
     317             : //     [a1, b1, c1, d1],
     318             : //     [a2, b2, c2, d2],
     319             : //     [a3, b3, c3, d3],
     320             : //     [a4, b4, c4, d4],
     321             : //   ]
     322             : // The result is in the following form:
     323             : //   [ a1+a2+a3+a4, b1+b2+b3+b4, c1+c2+c3+c4, d1+d2+d3+d4]
     324             : // Same rules for the input vectors apply to the VectorAggregator as for the Aggregator (they must stay alive)
     325             : template <typename T>
     326             : struct VectorAggregator : public std::enable_shared_from_this<VectorAggregator<T>> {
     327             :     using AggregatorType = Aggregator<T>;
     328             :     using VectorType = std::vector<T>;
     329             :     using VectorPtrType = std::shared_ptr<VectorType>;
     330             :     using VectorVectorType = Span<VectorPtrType>;
     331             :     using DoneCallback = std::function<void(const VectorPtrType& agg)>;
     332             :     DoneCallback doneCallback;
     333             : 
     334             :     VectorVectorType vecs;
     335             :     bool parallel;
     336             :     ctpl::thread_pool& workerPool;
     337             : 
     338        6378 :     std::atomic<size_t> doneCount{0};
     339             : 
     340             :     VectorPtrType result;
     341             :     size_t vecSize;
     342             : 
     343       19134 :     VectorAggregator(VectorVectorType _vecs,
     344             :                      bool _parallel, ctpl::thread_pool& _workerPool,
     345             :                      DoneCallback _doneCallback) :
     346        6378 :             doneCallback(std::move(_doneCallback)),
     347        6378 :             vecs(_vecs),
     348        6378 :             parallel(_parallel),
     349        6378 :             workerPool(_workerPool)
     350        6378 :     {
     351        6378 :         assert(!vecs.empty());
     352        6378 :         vecSize = vecs[0]->size();
     353        6378 :         result = std::make_shared<VectorType>(vecSize);
     354       12756 :     }
     355             : 
     356        6378 :     void Start()
     357             :     {
     358       23763 :         for (size_t i = 0; i < vecSize; i++) {
     359       17385 :             std::vector<const T*> tmp(vecs.size());
     360       82915 :             for (size_t j = 0; j < vecs.size(); j++) {
     361       65530 :                 tmp[j] = &(*vecs[j])[i];
     362       65530 :             }
     363             : 
     364       17385 :             auto self(this->shared_from_this());
     365       34448 :             auto aggregator = std::make_shared<AggregatorType>(Span{tmp}, parallel, workerPool, [self, i](const T& agg) {self->CheckDone(agg, i);});
     366       17385 :             aggregator->Start();
     367       17385 :         }
     368        6378 :     }
     369             : 
     370       16687 :     void CheckDone(const T& agg, size_t idx)
     371             :     {
     372       16687 :         (*result)[idx] = agg;
     373       16687 :         if (++doneCount == vecSize) {
     374        6378 :             doneCallback(result);
     375        6378 :         }
     376       16687 :     }
     377             : };
     378             : 
     379             : // See comment of AsyncVerifyContributionShares for a description on what this does
     380             : // Same rules as in Aggregator apply for the inputs
     381             : struct ContributionVerifier : public std::enable_shared_from_this<ContributionVerifier> {
     382        2438 :     struct BatchState {
     383             :         size_t start;
     384             :         size_t count;
     385             : 
     386             :         BLSVerificationVectorPtr vvec;
     387             :         CBLSSecretKey skShare;
     388             : 
     389             :         // starts with 0 and is incremented if either vvec or skShare aggregation finishes. If it reaches 2, we know
     390             :         // that aggregation for this batch is fully done. We can then start verification.
     391             :         std::unique_ptr<std::atomic<int> > aggDone;
     392             : 
     393             :         // we can't directly update a vector<bool> in parallel
     394             :         // as vector<bool> is not thread safe (uses bitsets internally)
     395             :         // so we must use vector<char> temporarily and concatenate/convert
     396             :         // each batch result into a final vector<bool>
     397             :         std::vector<char> verifyResults;
     398             :     };
     399             : 
     400             :     CBLSId forId;
     401             :     Span<BLSVerificationVectorPtr> vvecs;
     402             :     Span<CBLSSecretKey> skShares;
     403             :     size_t batchSize;
     404             :     bool parallel;
     405             :     bool aggregated;
     406             : 
     407             :     ctpl::thread_pool& workerPool;
     408             : 
     409        2438 :     size_t batchCount{1};
     410             :     size_t verifyCount;
     411             : 
     412             :     std::vector<BatchState> batchStates;
     413        2438 :     std::atomic<size_t> verifyDoneCount{0};
     414             :     std::function<void(const std::vector<bool>&)> doneCallback;
     415             : 
     416        7314 :     ContributionVerifier(CBLSId _forId, Span<BLSVerificationVectorPtr> _vvecs,
     417             :                          Span<CBLSSecretKey> _skShares, size_t _batchSize,
     418             :                          bool _parallel, bool _aggregated, ctpl::thread_pool& _workerPool,
     419             :                          std::function<void(const std::vector<bool>&)> _doneCallback) :
     420        2438 :         forId(std::move(_forId)),
     421        2438 :         vvecs(_vvecs),
     422        2438 :         skShares(_skShares),
     423        2438 :         batchSize(_batchSize),
     424        2438 :         parallel(_parallel),
     425        2438 :         aggregated(_aggregated),
     426        2438 :         workerPool(_workerPool),
     427        2438 :         verifyCount(_vvecs.size()),
     428        2438 :         doneCallback(std::move(_doneCallback))
     429        2438 :     {
     430        4876 :     }
     431             : 
     432        2438 :     void Start()
     433             :     {
     434        2438 :         if (!aggregated) {
     435             :             // treat all inputs as one large batch
     436           0 :             batchSize = vvecs.size();
     437           0 :         } else {
     438        2438 :             batchCount = (vvecs.size() + batchSize - 1) / batchSize;
     439             :         }
     440             : 
     441        2438 :         batchStates.resize(batchCount);
     442        4876 :         for (size_t i = 0; i < batchCount; i++) {
     443        2438 :             auto& batchState = batchStates[i];
     444             : 
     445        2438 :             batchState.aggDone = std::make_unique<std::atomic<int>>(0);
     446        2438 :             batchState.start = i * batchSize;
     447        2438 :             batchState.count = std::min(batchSize, vvecs.size() - batchState.start);
     448        2438 :             batchState.verifyResults.assign(batchState.count, 0);
     449        2438 :         }
     450             : 
     451        2438 :         if (aggregated) {
     452        2438 :             size_t batchCount2 = batchCount; // 'this' might get deleted while we're still looping
     453        4876 :             for (size_t i = 0; i < batchCount2; i++) {
     454        2438 :                 AsyncAggregate(i);
     455        2438 :             }
     456        2438 :         } else {
     457             :             // treat all inputs as a single batch and verify one-by-one
     458           0 :             AsyncVerifyBatchOneByOne(0);
     459             :         }
     460        2438 :     }
     461             : 
     462        2438 :     void Finish()
     463             :     {
     464        2438 :         size_t batchIdx = 0;
     465        2438 :         std::vector<bool> result(vvecs.size());
     466        4876 :         for (size_t i = 0; i < vvecs.size(); i += batchSize) {
     467        2438 :             const auto& batchState = batchStates[batchIdx++];
     468       11087 :             for (size_t j = 0; j < batchState.count; j++) {
     469        8649 :                 result[batchState.start + j] = batchState.verifyResults[j] != 0;
     470        8649 :             }
     471        2438 :         }
     472        2438 :         doneCallback(result);
     473        2438 :     }
     474             : 
     475        2438 :     void AsyncAggregate(size_t batchIdx)
     476             :     {
     477        2438 :         auto& batchState = batchStates[batchIdx];
     478             : 
     479             :         // aggregate vvecs and skShares of batch in parallel
     480        2438 :         auto self(this->shared_from_this());
     481        4876 :         auto vvecAgg = std::make_shared<VectorAggregator<CBLSPublicKey>>(vvecs.subspan(batchState.start, batchState.count), parallel, workerPool, [this, self, batchIdx] (const BLSVerificationVectorPtr& vvec) {HandleAggVvecDone(batchIdx, vvec);});
     482        4876 :         auto skShareAgg = std::make_shared<Aggregator<CBLSSecretKey>>(Span{skShares}.subspan(batchState.start, batchState.count), parallel, workerPool, [this, self, batchIdx] (const CBLSSecretKey& skShare) {HandleAggSkShareDone(batchIdx, skShare);});
     483             : 
     484        2438 :         vvecAgg->Start();
     485        2438 :         skShareAgg->Start();
     486        2438 :     }
     487             : 
     488        2438 :     void HandleAggVvecDone(size_t batchIdx, const BLSVerificationVectorPtr& vvec)
     489             :     {
     490        2438 :         auto& batchState = batchStates[batchIdx];
     491        2438 :         batchState.vvec = vvec;
     492        2438 :         if (++(*batchState.aggDone) == 2) {
     493         427 :             HandleAggDone(batchIdx);
     494         427 :         }
     495        2438 :     }
     496        2438 :     void HandleAggSkShareDone(size_t batchIdx, const CBLSSecretKey& skShare)
     497             :     {
     498        2438 :         auto& batchState = batchStates[batchIdx];
     499        2438 :         batchState.skShare = skShare;
     500        2438 :         if (++(*batchState.aggDone) == 2) {
     501        2009 :             HandleAggDone(batchIdx);
     502        2009 :         }
     503        2438 :     }
     504             : 
     505        2462 :     void HandleVerifyDone(size_t count)
     506             :     {
     507        2462 :         size_t c = verifyDoneCount += count;
     508        2462 :         if (c == verifyCount) {
     509        2438 :             Finish();
     510        2438 :         }
     511        2462 :     }
     512             : 
     513        2438 :     void HandleAggDone(size_t batchIdx)
     514             :     {
     515        2438 :         auto& batchState = batchStates[batchIdx];
     516             : 
     517        2438 :         if (batchState.vvec == nullptr || batchState.vvec->empty() || !batchState.skShare.IsValid()) {
     518             :             // something went wrong while aggregating and there is nothing we can do now except mark the whole batch as failed
     519             :             // this can only happen if inputs were invalid in some way
     520           2 :             batchState.verifyResults.assign(batchState.count, 0);
     521           2 :             HandleVerifyDone(batchState.count);
     522           2 :             return;
     523             :         }
     524             : 
     525        2436 :         AsyncAggregatedVerifyBatch(batchIdx);
     526        2436 :     }
     527             : 
     528        2438 :     void AsyncAggregatedVerifyBatch(size_t batchIdx)
     529             :     {
     530        2438 :         auto self(this->shared_from_this());
     531        4876 :         auto f = [this, self, batchIdx](int threadId) {
     532        2438 :             auto& batchState = batchStates[batchIdx];
     533        2438 :             bool result = Verify(batchState.vvec, batchState.skShare);
     534        2438 :             if (result) {
     535             :                 // whole batch is valid
     536        2426 :                 batchState.verifyResults.assign(batchState.count, 1);
     537        2426 :                 HandleVerifyDone(batchState.count);
     538        2426 :             } else {
     539             :                 // at least one entry in the batch is invalid, revert to per-contribution verification (but parallelized)
     540          12 :                 AsyncVerifyBatchOneByOne(batchIdx);
     541             :             }
     542        2438 :         };
     543        2438 :         PushOrDoWork(std::move(f));
     544        2438 :     }
     545             : 
     546          12 :     void AsyncVerifyBatchOneByOne(size_t batchIdx)
     547             :     {
     548          12 :         size_t count = batchStates[batchIdx].count;
     549          12 :         batchStates[batchIdx].verifyResults.assign(count, 0);
     550          48 :         for (size_t i = 0; i < count; i++) {
     551          36 :             auto self(this->shared_from_this());
     552          71 :             auto f = [this, self, i, batchIdx](int threadId) {
     553          35 :                 auto& batchState = batchStates[batchIdx];
     554          35 :                 batchState.verifyResults[i] = Verify(vvecs[batchState.start + i], skShares[batchState.start + i]);
     555          35 :                 HandleVerifyDone(1);
     556          35 :             };
     557          36 :             PushOrDoWork(std::move(f));
     558          36 :         }
     559          12 :     }
     560             : 
     561        2473 :     bool Verify(const BLSVerificationVectorPtr& vvec, const CBLSSecretKey& skShare) const
     562             :     {
     563        2473 :         CBLSPublicKey pk1;
     564        2473 :         if (!pk1.PublicKeyShare(*vvec, forId)) {
     565           0 :             return false;
     566             :         }
     567             : 
     568        2473 :         CBLSPublicKey pk2 = skShare.GetPublicKey();
     569        2473 :         return pk1 == pk2;
     570        2473 :     }
     571             : 
     572             :     template <typename Callable>
     573        2474 :     void PushOrDoWork(Callable&& f)
     574             :     {
     575        2474 :         if (parallel) {
     576        2474 :             workerPool.push(std::forward<Callable>(f));
     577        2474 :         } else {
     578           0 :             f(0);
     579             :         }
     580        2474 :     }
     581             : };
     582             : 
     583        3940 : void CBLSWorker::AsyncBuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs, bool parallel,
     584             :                                                     std::function<void(const BLSVerificationVectorPtr&)> doneCallback)
     585             : {
     586        3940 :     if (vvecs.empty()) {
     587           0 :         doneCallback(nullptr);
     588           0 :         return;
     589             :     }
     590        3940 :     if (!VerifyVerificationVectors(vvecs)) {
     591           0 :         doneCallback(nullptr);
     592           0 :         return;
     593             :     }
     594             : 
     595        3940 :     auto agg = std::make_shared<VectorAggregator<CBLSPublicKey>>(vvecs, parallel, workerPool, std::move(doneCallback));
     596        3940 :     agg->Start();
     597        3940 : }
     598             : 
     599        3940 : std::future<BLSVerificationVectorPtr> CBLSWorker::AsyncBuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs, bool parallel)
     600             : {
     601        3940 :     auto p = BuildFutureDoneCallback<BLSVerificationVectorPtr>();
     602        3940 :     AsyncBuildQuorumVerificationVector(vvecs, parallel, std::move(p.first));
     603        3940 :     return std::move(p.second);
     604        3940 : }
     605             : 
     606        3940 : BLSVerificationVectorPtr CBLSWorker::BuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs, bool parallel)
     607             : {
     608        3940 :     return AsyncBuildQuorumVerificationVector(vvecs, parallel).get();
     609           0 : }
     610             : 
     611             : template <typename T>
     612        3948 : void AsyncAggregateHelper(ctpl::thread_pool& workerPool, Span<T> vec, bool parallel,
     613             :                           std::function<void(const T&)> doneCallback)
     614             : {
     615        3948 :     if (vec.empty()) {
     616           0 :         doneCallback(T());
     617           0 :         return;
     618             :     }
     619        3948 :     if (!VerifyVectorHelper(vec)) {
     620           0 :         doneCallback(T());
     621           0 :         return;
     622             :     }
     623             : 
     624        3948 :     auto agg = std::make_shared<Aggregator<T>>(vec, parallel, workerPool, std::move(doneCallback));
     625        3948 :     agg->Start();
     626        3948 : }
     627             : 
     628        3948 : void CBLSWorker::AsyncAggregateSecretKeys(Span<CBLSSecretKey> secKeys, bool parallel,
     629             :                                           std::function<void(const CBLSSecretKey&)> doneCallback)
     630             : {
     631        3948 :     AsyncAggregateHelper(workerPool, secKeys, parallel, std::move(doneCallback));
     632        3948 : }
     633             : 
     634        3948 : std::future<CBLSSecretKey> CBLSWorker::AsyncAggregateSecretKeys(Span<CBLSSecretKey> secKeys, bool parallel)
     635             : {
     636        3948 :     auto p = BuildFutureDoneCallback<CBLSSecretKey>();
     637        3948 :     AsyncAggregateSecretKeys(secKeys, parallel, std::move(p.first));
     638        3948 :     return std::move(p.second);
     639        3948 : }
     640             : 
     641        3948 : CBLSSecretKey CBLSWorker::AggregateSecretKeys(Span<CBLSSecretKey> secKeys, bool parallel)
     642             : {
     643        3948 :     return AsyncAggregateSecretKeys(secKeys, parallel).get();
     644           0 : }
     645             : 
     646           0 : void CBLSWorker::AsyncAggregatePublicKeys(Span<CBLSPublicKey> pubKeys, bool parallel,
     647             :                                           std::function<void(const CBLSPublicKey&)> doneCallback)
     648             : {
     649           0 :     AsyncAggregateHelper(workerPool, pubKeys, parallel, std::move(doneCallback));
     650           0 : }
     651             : 
     652           0 : std::future<CBLSPublicKey> CBLSWorker::AsyncAggregatePublicKeys(Span<CBLSPublicKey> pubKeys, bool parallel)
     653             : {
     654           0 :     auto p = BuildFutureDoneCallback<CBLSPublicKey>();
     655           0 :     AsyncAggregatePublicKeys(pubKeys, parallel, std::move(p.first));
     656           0 :     return std::move(p.second);
     657           0 : }
     658             : 
     659           0 : void CBLSWorker::AsyncAggregateSigs(Span<CBLSSignature> sigs, bool parallel,
     660             :                                     std::function<void(const CBLSSignature&)> doneCallback)
     661             : {
     662           0 :     AsyncAggregateHelper(workerPool, sigs, parallel, std::move(doneCallback));
     663           0 : }
     664             : 
     665           0 : std::future<CBLSSignature> CBLSWorker::AsyncAggregateSigs(Span<CBLSSignature> sigs, bool parallel)
     666             : {
     667           0 :     auto p = BuildFutureDoneCallback<CBLSSignature>();
     668           0 :     AsyncAggregateSigs(sigs, parallel, std::move(p.first));
     669           0 :     return std::move(p.second);
     670           0 : }
     671             : 
     672       18226 : CBLSPublicKey CBLSWorker::BuildPubKeyShare(const BLSVerificationVectorPtr& vvec, const CBLSId& id)
     673             : {
     674       18226 :     CBLSPublicKey pkShare;
     675       18226 :     pkShare.PublicKeyShare(*vvec, id);
     676       18226 :     return pkShare;
     677             : }
     678             : 
     679        2438 : void CBLSWorker::AsyncVerifyContributionShares(const CBLSId& forId, Span<BLSVerificationVectorPtr> vvecs, Span<CBLSSecretKey> skShares,
     680             :                                                bool parallel, bool aggregated, std::function<void(const std::vector<bool>&)> doneCallback)
     681             : {
     682        2438 :     if (!forId.IsValid() || !VerifyVerificationVectors(vvecs)) {
     683           0 :         std::vector<bool> result;
     684           0 :         result.assign(vvecs.size(), false);
     685           0 :         doneCallback(result);
     686             :         return;
     687           0 :     }
     688             : 
     689        2438 :     auto verifier = std::make_shared<ContributionVerifier>(forId, vvecs, skShares, 8, parallel, aggregated, workerPool, std::move(doneCallback));
     690        2438 :     verifier->Start();
     691        2438 : }
     692             : 
     693        2438 : std::future<std::vector<bool> > CBLSWorker::AsyncVerifyContributionShares(const CBLSId& forId, Span<BLSVerificationVectorPtr> vvecs, Span<CBLSSecretKey> skShares,
     694             :                                                                           bool parallel, bool aggregated)
     695             : {
     696        2438 :     auto p = BuildFutureDoneCallback<std::vector<bool> >();
     697        2438 :     AsyncVerifyContributionShares(forId, vvecs, skShares, parallel, aggregated, std::move(p.first));
     698        2438 :     return std::move(p.second);
     699        2438 : }
     700             : 
     701        2438 : std::vector<bool> CBLSWorker::VerifyContributionShares(const CBLSId& forId, Span<BLSVerificationVectorPtr> vvecs, Span<CBLSSecretKey> skShares,
     702             :                                                        bool parallel, bool aggregated)
     703             : {
     704        2438 :     return AsyncVerifyContributionShares(forId, vvecs, skShares, parallel, aggregated).get();
     705           0 : }
     706             : 
     707          36 : std::future<bool> CBLSWorker::AsyncVerifyContributionShare(const CBLSId& forId,
     708             :                                                            const BLSVerificationVectorPtr& vvec,
     709             :                                                            const CBLSSecretKey& skContribution)
     710             : {
     711          36 :     if (!forId.IsValid() || !VerifyVerificationVector(*vvec)) {
     712           0 :         auto p = BuildFutureDoneCallback<bool>();
     713           0 :         p.first(false);
     714           0 :         return std::move(p.second);
     715           0 :     }
     716             : 
     717          72 :     auto f = [&forId, &vvec, &skContribution](int threadId) {
     718          36 :         CBLSPublicKey pk1;
     719          36 :         if (!pk1.PublicKeyShare(*vvec, forId)) {
     720           0 :             return false;
     721             :         }
     722             : 
     723          36 :         CBLSPublicKey pk2 = skContribution.GetPublicKey();
     724          36 :         return pk1 == pk2;
     725          36 :     };
     726          36 :     return workerPool.push(f);
     727          36 : }
     728             : 
     729        8988 : bool CBLSWorker::VerifyVerificationVector(Span<CBLSPublicKey> vvec)
     730             : {
     731        8988 :     return VerifyVectorHelper(vvec);
     732             : }
     733             : 
     734        6378 : bool CBLSWorker::VerifyVerificationVectors(Span<BLSVerificationVectorPtr> vvecs)
     735             : {
     736        6378 :     std::set<uint256> set;
     737       29693 :     for (const auto& vvec : vvecs) {
     738       23315 :         if (vvec == nullptr) {
     739           0 :             return false;
     740             :         }
     741       23315 :         if (vvec->size() != vvecs[0]->size()) {
     742           0 :             return false;
     743             :         }
     744       88847 :         for (size_t j = 0; j < vvec->size(); j++) {
     745       65532 :             if (!(*vvec)[j].IsValid()) {
     746           0 :                 return false;
     747             :             }
     748             :             // check duplicates
     749       65531 :             if (!set.emplace((*vvec)[j].GetHash()).second) {
     750           0 :                 return false;
     751             :             }
     752       65532 :         }
     753             :     }
     754             : 
     755        6378 :     return true;
     756        6380 : }
     757             : 
     758           0 : void CBLSWorker::AsyncSign(const CBLSSecretKey& secKey, const uint256& msgHash, const CBLSWorker::SignDoneCallback& doneCallback)
     759             : {
     760           0 :     workerPool.push([secKey, msgHash, doneCallback](int threadId) {
     761           0 :         doneCallback(secKey.Sign(msgHash, bls::bls_legacy_scheme.load()));
     762           0 :     });
     763           0 : }
     764             : 
     765           0 : void CBLSWorker::AsyncVerifySig(const CBLSSignature& sig, const CBLSPublicKey& pubKey, const uint256& msgHash,
     766             :                                 CBLSWorker::SigVerifyDoneCallback doneCallback, CancelCond cancelCond)
     767             : {
     768           0 :     if (!sig.IsValid() || !pubKey.IsValid()) {
     769           0 :         doneCallback(false);
     770           0 :         return;
     771             :     }
     772             : 
     773           0 :     std::unique_lock<std::mutex> l(sigVerifyMutex);
     774             : 
     775           0 :     bool foundDuplicate = std::ranges::any_of(sigVerifyQueue,
     776           0 :                                               [&msgHash](const auto& job) { return job.msgHash == msgHash; });
     777             : 
     778           0 :     if (foundDuplicate) {
     779             :         // batched/aggregated verification does not allow duplicate hashes, so we push what we currently have and start
     780             :         // with a fresh batch
     781           0 :         PushSigVerifyBatch();
     782           0 :     }
     783             : 
     784           0 :     sigVerifyQueue.emplace_back(std::move(doneCallback), std::move(cancelCond), sig, pubKey, msgHash);
     785           0 :     if (sigVerifyBatchesInProgress == 0 || sigVerifyQueue.size() >= SIG_VERIFY_BATCH_SIZE) {
     786           0 :         PushSigVerifyBatch();
     787           0 :     }
     788           0 : }
     789             : 
     790           0 : std::future<bool> CBLSWorker::AsyncVerifySig(const CBLSSignature& sig, const CBLSPublicKey& pubKey, const uint256& msgHash, CancelCond cancelCond)
     791             : {
     792           0 :     auto p = BuildFutureDoneCallback2<bool>();
     793           0 :     AsyncVerifySig(sig, pubKey, msgHash, std::move(p.first), std::move(cancelCond));
     794           0 :     return std::move(p.second);
     795           0 : }
     796             : 
     797           0 : bool CBLSWorker::IsAsyncVerifyInProgress()
     798             : {
     799           0 :     std::unique_lock<std::mutex> l(sigVerifyMutex);
     800           0 :     return sigVerifyBatchesInProgress != 0;
     801           0 : }
     802             : 
     803             : // sigVerifyMutex must be held while calling
     804           0 : void CBLSWorker::PushSigVerifyBatch()
     805             : {
     806           0 :     auto f = [this](int threadId, const std::shared_ptr<std::vector<SigVerifyJob> >& _jobs) {
     807           0 :         auto& jobs = *_jobs;
     808           0 :         if (jobs.size() == 1) {
     809           0 :             const auto& job = jobs[0];
     810           0 :             if (!job.cancelCond()) {
     811           0 :                 bool valid = job.sig.VerifyInsecure(job.pubKey, job.msgHash);
     812           0 :                 job.doneCallback(valid);
     813           0 :             }
     814           0 :             std::unique_lock<std::mutex> l(sigVerifyMutex);
     815           0 :             sigVerifyBatchesInProgress--;
     816           0 :             if (!sigVerifyQueue.empty()) {
     817           0 :                 PushSigVerifyBatch();
     818           0 :             }
     819             :             return;
     820           0 :         }
     821             : 
     822           0 :         CBLSSignature aggSig;
     823           0 :         std::vector<size_t> indexes;
     824           0 :         std::vector<CBLSPublicKey> pubKeys;
     825           0 :         std::vector<uint256> msgHashes;
     826           0 :         indexes.reserve(jobs.size());
     827           0 :         pubKeys.reserve(jobs.size());
     828           0 :         msgHashes.reserve(jobs.size());
     829           0 :         for (size_t i = 0; i < jobs.size(); i++) {
     830           0 :             auto& job = jobs[i];
     831           0 :             if (job.cancelCond()) {
     832           0 :                 continue;
     833             :             }
     834           0 :             if (pubKeys.empty()) {
     835           0 :                 aggSig = job.sig;
     836           0 :             } else {
     837           0 :                 aggSig.AggregateInsecure(job.sig);
     838             :             }
     839           0 :             indexes.emplace_back(i);
     840           0 :             pubKeys.emplace_back(job.pubKey);
     841           0 :             msgHashes.emplace_back(job.msgHash);
     842           0 :         }
     843             : 
     844           0 :         if (!pubKeys.empty()) {
     845           0 :             bool allValid = aggSig.VerifyInsecureAggregated(pubKeys, msgHashes);
     846           0 :             if (allValid) {
     847           0 :                 for (size_t i = 0; i < pubKeys.size(); i++) {
     848           0 :                     jobs[indexes[i]].doneCallback(true);
     849           0 :                 }
     850           0 :             } else {
     851             :                 // one or more sigs were not valid, revert to per-sig verification
     852             :                 // TODO this could be improved if we would cache pairing results in some way as the previous aggregated verification already calculated all the pairings for the hashes
     853           0 :                 for (size_t i = 0; i < pubKeys.size(); i++) {
     854           0 :                     const auto& job = jobs[indexes[i]];
     855           0 :                     bool valid = job.sig.VerifyInsecure(job.pubKey, job.msgHash);
     856           0 :                     job.doneCallback(valid);
     857           0 :                 }
     858             :             }
     859           0 :         }
     860             : 
     861           0 :         std::unique_lock<std::mutex> l(sigVerifyMutex);
     862           0 :         sigVerifyBatchesInProgress--;
     863           0 :         if (!sigVerifyQueue.empty()) {
     864           0 :             PushSigVerifyBatch();
     865           0 :         }
     866           0 :     };
     867             : 
     868           0 :     auto batch = std::make_shared<std::vector<SigVerifyJob> >(std::move(sigVerifyQueue));
     869           0 :     sigVerifyQueue.reserve(SIG_VERIFY_BATCH_SIZE);
     870             : 
     871           0 :     sigVerifyBatchesInProgress++;
     872           0 :     workerPool.push(f, batch);
     873           0 : }

Generated by: LCOV version 1.16