LCOV - code coverage report
Current view: top level - src/bls - bls_worker.cpp (source / functions) Hit Total Coverage
Test: test_dash_coverage.info Lines: 14 532 2.6 %
Date: 2026-06-25 07:23:51 Functions: 6 254 2.4 %

          Line data    Source code
       1             : // Copyright (c) 2018-2025 The Dash Core developers
       2             : // Distributed under the MIT/X11 software license, see the accompanying
       3             : // file COPYING or http://www.opensource.org/licenses/mit-license.php.
       4             : 
       5             : #include <bls/bls_worker.h>
       6             : #include <hash.h>
       7             : 
       8             : #include <util/system.h>
       9             : 
      10             : #include <memory>
      11             : #include <ranges>
      12             : #include <utility>
      13             : 
      14             : template <typename T>
      15           0 : bool VerifyVectorHelper(Span<T> vec)
      16             : {
      17           0 :     std::set<uint256> set;
      18           0 :     for (const auto& item : vec) {
      19           0 :         if (!item.IsValid())
      20           0 :             return false;
      21             :         // check duplicates
      22           0 :         if (!set.emplace(item.GetHash()).second) {
      23           0 :             return false;
      24             :         }
      25             :     }
      26           0 :     return true;
      27           0 : }
      28             : 
      29             : // Creates a doneCallback and a future. The doneCallback simply finishes the future
      30             : template <typename T>
      31           0 : std::pair<std::function<void(const T&)>, std::future<T> > BuildFutureDoneCallback()
      32             : {
      33           0 :     auto p = std::make_shared<std::promise<T> >();
      34           0 :     std::function<void(const T&)> f = [p](const T& v) {
      35           0 :         p->set_value(v);
      36           0 :     };
      37           0 :     return std::make_pair(std::move(f), p->get_future());
      38           0 : }
      39             : template <typename T>
      40           0 : std::pair<std::function<void(T)>, std::future<T> > BuildFutureDoneCallback2()
      41             : {
      42           0 :     auto p = std::make_shared<std::promise<T> >();
      43           0 :     std::function<void(const T&)> f = [p](T v) {
      44           0 :         p->set_value(v);
      45           0 :     };
      46           0 :     return std::make_pair(std::move(f), p->get_future());
      47           0 : }
      48             : 
      49             : 
      50             : /////
      51             : 
      52         360 : CBLSWorker::CBLSWorker() = default;
      53             : 
      54         360 : CBLSWorker::~CBLSWorker()
      55         180 : {
      56         180 :     Stop();
      57         360 : }
      58             : 
      59         180 : void CBLSWorker::Start(int16_t worker_count)
      60             : {
      61         180 :     assert(worker_count > 0);
      62         180 :     workerPool.resize(worker_count);
      63         180 :     RenameThreadPool(workerPool, "bls-work");
      64         180 : }
      65             : 
      66         360 : void CBLSWorker::Stop()
      67             : {
      68         360 :     workerPool.clear_queue();
      69         360 :     workerPool.stop(true);
      70         360 : }
      71             : 
      72           0 : bool CBLSWorker::GenerateContributions(int quorumThreshold, Span<CBLSId> ids, BLSVerificationVectorPtr& vvecRet, std::vector<CBLSSecretKey>& skSharesRet)
      73             : {
      74           0 :     auto svec = std::vector<CBLSSecretKey>((size_t)quorumThreshold);
      75           0 :     vvecRet = std::make_shared<std::vector<CBLSPublicKey>>((size_t)quorumThreshold);
      76           0 :     skSharesRet.resize(ids.size());
      77             : 
      78           0 :     for (int i = 0; i < quorumThreshold; i++) {
      79           0 :         svec[i].MakeNewKey();
      80           0 :     }
      81           0 :     size_t batchSize = 8;
      82           0 :     std::vector<std::future<bool>> futures;
      83           0 :     futures.reserve((quorumThreshold / batchSize + ids.size() / batchSize) + 2);
      84             : 
      85           0 :     for (size_t i = 0; i < size_t(quorumThreshold); i += batchSize) {
      86           0 :         size_t start = i;
      87           0 :         size_t count = std::min(batchSize, quorumThreshold - start);
      88           0 :         auto f = [&, start, count](int threadId) {
      89           0 :             for (size_t j = start; j < start + count; j++) {
      90           0 :                 (*vvecRet)[j] = svec[j].GetPublicKey();
      91           0 :             }
      92           0 :             return true;
      93             :         };
      94           0 :         futures.emplace_back(workerPool.push(f));
      95           0 :     }
      96             : 
      97           0 :     for (size_t i = 0; i < ids.size(); i += batchSize) {
      98           0 :         size_t start = i;
      99           0 :         size_t count = std::min(batchSize, ids.size() - start);
     100           0 :         auto f = [&, start, count](int threadId) {
     101           0 :             for (size_t j = start; j < start + count; j++) {
     102           0 :                 if (!skSharesRet[j].SecretKeyShare(svec, ids[j])) {
     103           0 :                     return false;
     104             :                 }
     105           0 :             }
     106           0 :             return true;
     107           0 :         };
     108           0 :         futures.emplace_back(workerPool.push(f));
     109           0 :     }
     110           0 :     return std::ranges::all_of(futures, [](auto& f) { return f.get(); });
     111           0 : }
     112             : 
     113             : // aggregates a single vector of BLS objects in parallel
     114             : // the input vector is split into batches and each batch is aggregated in parallel
     115             : // when enough batches are finished to form a new batch, the new batch is queued for further parallel aggregation
     116             : // when no more batches can be created from finished batch results, the final aggregated is created and the doneCallback
     117             : // called.
     118             : // The Aggregator object needs to be created on the heap, and it will delete itself after calling the doneCallback
     119             : // The input vector is not copied into the Aggregator but instead a vector of pointers to the original entries from the
     120             : // input vector is stored. This means that the input vector must stay alive for the whole lifetime of the Aggregator
     121             : template <typename T>
     122             : struct Aggregator : public std::enable_shared_from_this<Aggregator<T>> {
     123           0 :     const size_t BATCH_SIZE{16};
     124             :     std::shared_ptr<std::vector<const T*> > inputVec;
     125             : 
     126             :     bool parallel;
     127             :     ctpl::thread_pool& workerPool;
     128             : 
     129             :     std::mutex m;
     130             :     // items in the queue are all intermediate aggregation results of finished batches.
     131             :     // The intermediate results must be deleted by us again (which we do in SyncAggregateAndPushAggQueue)
     132             :     ctpl::detail::Queue<T*> aggQueue;
     133           0 :     std::atomic<size_t> aggQueueSize{0};
     134             : 
     135             :     // keeps track of currently queued/in-progress batches. If it reaches 0, we are done
     136           0 :     std::atomic<size_t> waitCount{0};
     137             : 
     138             :     using DoneCallback = std::function<void(const T& agg)>;
     139             :     DoneCallback doneCallback;
     140             : 
     141             :     // TP can either be a pointer or a reference
     142             :     template <typename TP>
     143           0 :     Aggregator(Span<TP> _inputSpan, bool _parallel,
     144             :                ctpl::thread_pool& _workerPool,
     145             :                DoneCallback _doneCallback) :
     146           0 :             inputVec(std::make_shared<std::vector<const T*>>(_inputSpan.size())),
     147           0 :             parallel(_parallel),
     148           0 :             workerPool(_workerPool),
     149           0 :             doneCallback(std::move(_doneCallback))
     150           0 :     {
     151           0 :         for (size_t i = 0; i < _inputSpan.size(); i++) {
     152           0 :             (*inputVec)[i] = pointer(_inputSpan[i]);
     153           0 :         }
     154           0 :     }
     155             : 
     156           0 :     const T* pointer(const T& v) { return &v; }
     157           0 :     const T* pointer(const T* v) { return v; }
     158             : 
     159             :     // Starts aggregation.
     160             :     // If parallel=true, then this will return fast, otherwise this will block until aggregation is done
     161           0 :     void Start()
     162             :     {
     163           0 :         size_t batchCount = (inputVec->size() + BATCH_SIZE - 1) / BATCH_SIZE;
     164             : 
     165           0 :         if (!parallel) {
     166           0 :             if (inputVec->size() == 1) {
     167           0 :                 doneCallback(*(*inputVec)[0]);
     168           0 :             } else {
     169           0 :                 doneCallback(SyncAggregate(Span{*inputVec}, 0, inputVec->size()));
     170             :             }
     171           0 :             return;
     172             :         }
     173             : 
     174           0 :         if (batchCount == 1) {
     175             :             // just a single batch of work, take a shortcut.
     176           0 :             auto self(this->shared_from_this());
     177           0 :             PushWork([this, self](int threadId) {
     178           0 :                 if (inputVec->size() == 1) {
     179           0 :                     doneCallback(*(*inputVec)[0]);
     180           0 :                 } else {
     181           0 :                     doneCallback(SyncAggregate(Span{*inputVec}, 0, inputVec->size()));
     182             :                 }
     183           0 :             });
     184             :             return;
     185           0 :         }
     186             : 
     187             :         // increment wait counter as otherwise the first finished async aggregation might signal that we're done
     188           0 :         IncWait();
     189           0 :         for (size_t i = 0; i < batchCount; i++) {
     190           0 :             size_t start = i * BATCH_SIZE;
     191           0 :             size_t count = std::min(BATCH_SIZE, inputVec->size() - start);
     192           0 :             AsyncAggregateAndPushAggQueue(inputVec, start, count, false);
     193           0 :         }
     194             :         // this will decrement the wait counter and in most cases NOT finish, as async work is still in progress
     195           0 :         CheckDone();
     196           0 :     }
     197             : 
     198           0 :     void IncWait()
     199             :     {
     200           0 :         ++waitCount;
     201           0 :     }
     202             : 
     203           0 :     void CheckDone()
     204             :     {
     205           0 :         if (--waitCount == 0) {
     206           0 :             Finish();
     207           0 :         }
     208           0 :     }
     209             : 
     210           0 :     void Finish()
     211             :     {
     212             :         // All async work is done, but we might have items in the aggQueue which are the results of the async
     213             :         // work. This is the case when these did not add up to a new batch. In this case, we have to aggregate
     214             :         // the items into the final result
     215             : 
     216           0 :         std::vector<T*> rem(aggQueueSize);
     217           0 :         for (size_t i = 0; i < rem.size(); i++) {
     218           0 :             T* p = nullptr;
     219           0 :             bool s = aggQueue.pop(p);
     220           0 :             assert(s);
     221           0 :             rem[i] = p;
     222           0 :         }
     223             : 
     224           0 :         T r;
     225           0 :         if (rem.size() == 1) {
     226             :             // just one intermediate result, which is actually the final result
     227           0 :             r = *rem[0];
     228           0 :         } else {
     229             :             // multiple intermediate results left which did not add up to a new batch. aggregate them now
     230           0 :             r = SyncAggregate(Span{rem}, 0, rem.size());
     231             :         }
     232             : 
     233             :         // all items which are left in the queue are intermediate results, so we must delete them
     234           0 :         for (size_t i = 0; i < rem.size(); i++) {
     235           0 :             delete rem[i];
     236           0 :         }
     237           0 :         doneCallback(r);
     238           0 :     }
     239             : 
     240           0 :     void AsyncAggregateAndPushAggQueue(const std::shared_ptr<std::vector<const T*>>& vec, size_t start, size_t count, bool del)
     241             :     {
     242           0 :         IncWait();
     243           0 :         auto self(this->shared_from_this());
     244           0 :         PushWork([self, vec, start, count, del](int threadId){
     245           0 :             self->SyncAggregateAndPushAggQueue(vec, start, count, del);
     246           0 :         });
     247           0 :     }
     248             : 
     249           0 :     void SyncAggregateAndPushAggQueue(const std::shared_ptr<std::vector<const T*>>& vec, size_t start, size_t count, bool del)
     250             :     {
     251             :         // aggregate vec and push the intermediate result onto the work queue
     252           0 :         PushAggQueue(SyncAggregate(Span{*vec}, start, count));
     253           0 :         if (del) {
     254           0 :             for (size_t i = 0; i < count; i++) {
     255           0 :                 delete (*vec)[start + i];
     256           0 :             }
     257           0 :         }
     258           0 :         CheckDone();
     259           0 :     }
     260             : 
     261           0 :     void PushAggQueue(const T& v)
     262             :     {
     263           0 :         auto copyT = new T(v);
     264             :         try {
     265           0 :             aggQueue.push(copyT);
     266           0 :         } catch (...) {
     267           0 :             delete copyT;
     268           0 :             throw;
     269           0 :         }
     270             : 
     271           0 :         if (++aggQueueSize >= BATCH_SIZE) {
     272             :             // we've collected enough intermediate results to form a new batch.
     273           0 :             std::shared_ptr<std::vector<const T*> > newBatch;
     274             :             {
     275           0 :                 std::unique_lock<std::mutex> l(m);
     276           0 :                 if (aggQueueSize < BATCH_SIZE) {
     277             :                     // some other worker thread grabbed this batch
     278           0 :                     return;
     279             :                 }
     280           0 :                 newBatch = std::make_shared<std::vector<const T*>>(BATCH_SIZE);
     281             :                 // collect items for new batch
     282           0 :                 for (size_t i = 0; i < BATCH_SIZE; i++) {
     283           0 :                     T* p = nullptr;
     284           0 :                     bool s = aggQueue.pop(p);
     285           0 :                     assert(s);
     286           0 :                     (*newBatch)[i] = p;
     287           0 :                 }
     288           0 :                 aggQueueSize -= BATCH_SIZE;
     289           0 :             }
     290             : 
     291             :             // push new batch to work queue. del=true this time as these items are intermediate results and need to be deleted
     292             :             // after aggregation is done
     293           0 :             AsyncAggregateAndPushAggQueue(newBatch, 0, newBatch->size(), true);
     294           0 :         }
     295           0 :     }
     296             : 
     297             :     template <typename TP>
     298           0 :     T SyncAggregate(Span<TP> vec, size_t start, size_t count)
     299             :     {
     300           0 :         T result = *vec[start];
     301           0 :         for (size_t j = 1; j < count; j++) {
     302           0 :             result.AggregateInsecure(*vec[start + j]);
     303           0 :         }
     304           0 :         return result;
     305           0 :     }
     306             : 
     307             :     template <typename Callable>
     308           0 :     void PushWork(Callable&& f)
     309             :     {
     310           0 :         workerPool.push(f);
     311           0 :     }
     312             : };
     313             : 
     314             : // Aggregates multiple input vectors into a single output vector
     315             : // Inputs are in the following form:
     316             : //   [
     317             : //     [a1, b1, c1, d1],
     318             : //     [a2, b2, c2, d2],
     319             : //     [a3, b3, c3, d3],
     320             : //     [a4, b4, c4, d4],
     321             : //   ]
     322             : // The result is in the following form:
     323             : //   [ a1+a2+a3+a4, b1+b2+b3+b4, c1+c2+c3+c4, d1+d2+d3+d4]
     324             : // Same rules for the input vectors apply to the VectorAggregator as for the Aggregator (they must stay alive)
     325             : template <typename T>
     326             : struct VectorAggregator : public std::enable_shared_from_this<VectorAggregator<T>> {
     327             :     using AggregatorType = Aggregator<T>;
     328             :     using VectorType = std::vector<T>;
     329             :     using VectorPtrType = std::shared_ptr<VectorType>;
     330             :     using VectorVectorType = Span<VectorPtrType>;
     331             :     using DoneCallback = std::function<void(const VectorPtrType& agg)>;
     332             :     DoneCallback doneCallback;
     333             : 
     334             :     VectorVectorType vecs;
     335             :     bool parallel;
     336             :     ctpl::thread_pool& workerPool;
     337             : 
     338           0 :     std::atomic<size_t> doneCount{0};
     339             : 
     340             :     VectorPtrType result;
     341             :     size_t vecSize;
     342             : 
     343           0 :     VectorAggregator(VectorVectorType _vecs,
     344             :                      bool _parallel, ctpl::thread_pool& _workerPool,
     345             :                      DoneCallback _doneCallback) :
     346           0 :             doneCallback(std::move(_doneCallback)),
     347           0 :             vecs(_vecs),
     348           0 :             parallel(_parallel),
     349           0 :             workerPool(_workerPool)
     350           0 :     {
     351           0 :         assert(!vecs.empty());
     352           0 :         vecSize = vecs[0]->size();
     353           0 :         result = std::make_shared<VectorType>(vecSize);
     354           0 :     }
     355             : 
     356           0 :     void Start()
     357             :     {
     358           0 :         for (size_t i = 0; i < vecSize; i++) {
     359           0 :             std::vector<const T*> tmp(vecs.size());
     360           0 :             for (size_t j = 0; j < vecs.size(); j++) {
     361           0 :                 tmp[j] = &(*vecs[j])[i];
     362           0 :             }
     363             : 
     364           0 :             auto self(this->shared_from_this());
     365           0 :             auto aggregator = std::make_shared<AggregatorType>(Span{tmp}, parallel, workerPool, [self, i](const T& agg) {self->CheckDone(agg, i);});
     366           0 :             aggregator->Start();
     367           0 :         }
     368           0 :     }
     369             : 
     370           0 :     void CheckDone(const T& agg, size_t idx)
     371             :     {
     372           0 :         (*result)[idx] = agg;
     373           0 :         if (++doneCount == vecSize) {
     374           0 :             doneCallback(result);
     375           0 :         }
     376           0 :     }
     377             : };
     378             : 
     379             : // See comment of AsyncVerifyContributionShares for a description on what this does
     380             : // Same rules as in Aggregator apply for the inputs
     381             : struct ContributionVerifier : public std::enable_shared_from_this<ContributionVerifier> {
     382           0 :     struct BatchState {
     383             :         size_t start;
     384             :         size_t count;
     385             : 
     386             :         BLSVerificationVectorPtr vvec;
     387             :         CBLSSecretKey skShare;
     388             : 
     389             :         // starts with 0 and is incremented if either vvec or skShare aggregation finishes. If it reaches 2, we know
     390             :         // that aggregation for this batch is fully done. We can then start verification.
     391             :         std::unique_ptr<std::atomic<int> > aggDone;
     392             : 
     393             :         // we can't directly update a vector<bool> in parallel
     394             :         // as vector<bool> is not thread safe (uses bitsets internally)
     395             :         // so we must use vector<char> temporarily and concatenate/convert
     396             :         // each batch result into a final vector<bool>
     397             :         std::vector<char> verifyResults;
     398             :     };
     399             : 
     400             :     CBLSId forId;
     401             :     Span<BLSVerificationVectorPtr> vvecs;
     402             :     Span<CBLSSecretKey> skShares;
     403             :     size_t batchSize;
     404             :     bool parallel;
     405             :     bool aggregated;
     406             : 
     407             :     ctpl::thread_pool& workerPool;
     408             : 
     409           0 :     size_t batchCount{1};
     410             :     size_t verifyCount;
     411             : 
     412             :     std::vector<BatchState> batchStates;
     413           0 :     std::atomic<size_t> verifyDoneCount{0};
     414             :     std::function<void(const std::vector<bool>&)> doneCallback;
     415             : 
     416           0 :     ContributionVerifier(CBLSId _forId, Span<BLSVerificationVectorPtr> _vvecs,
     417             :                          Span<CBLSSecretKey> _skShares, size_t _batchSize,
     418             :                          bool _parallel, bool _aggregated, ctpl::thread_pool& _workerPool,
     419             :                          std::function<void(const std::vector<bool>&)> _doneCallback) :
     420           0 :         forId(std::move(_forId)),
     421           0 :         vvecs(_vvecs),
     422           0 :         skShares(_skShares),
     423           0 :         batchSize(_batchSize),
     424           0 :         parallel(_parallel),
     425           0 :         aggregated(_aggregated),
     426           0 :         workerPool(_workerPool),
     427           0 :         verifyCount(_vvecs.size()),
     428           0 :         doneCallback(std::move(_doneCallback))
     429           0 :     {
     430           0 :     }
     431             : 
     432           0 :     void Start()
     433             :     {
     434           0 :         if (!aggregated) {
     435             :             // treat all inputs as one large batch
     436           0 :             batchSize = vvecs.size();
     437           0 :         } else {
     438           0 :             batchCount = (vvecs.size() + batchSize - 1) / batchSize;
     439             :         }
     440             : 
     441           0 :         batchStates.resize(batchCount);
     442           0 :         for (size_t i = 0; i < batchCount; i++) {
     443           0 :             auto& batchState = batchStates[i];
     444             : 
     445           0 :             batchState.aggDone = std::make_unique<std::atomic<int>>(0);
     446           0 :             batchState.start = i * batchSize;
     447           0 :             batchState.count = std::min(batchSize, vvecs.size() - batchState.start);
     448           0 :             batchState.verifyResults.assign(batchState.count, 0);
     449           0 :         }
     450             : 
     451           0 :         if (aggregated) {
     452           0 :             size_t batchCount2 = batchCount; // 'this' might get deleted while we're still looping
     453           0 :             for (size_t i = 0; i < batchCount2; i++) {
     454           0 :                 AsyncAggregate(i);
     455           0 :             }
     456           0 :         } else {
     457             :             // treat all inputs as a single batch and verify one-by-one
     458           0 :             AsyncVerifyBatchOneByOne(0);
     459             :         }
     460           0 :     }
     461             : 
     462           0 :     void Finish()
     463             :     {
     464           0 :         size_t batchIdx = 0;
     465           0 :         std::vector<bool> result(vvecs.size());
     466           0 :         for (size_t i = 0; i < vvecs.size(); i += batchSize) {
     467           0 :             const auto& batchState = batchStates[batchIdx++];
     468           0 :             for (size_t j = 0; j < batchState.count; j++) {
     469           0 :                 result[batchState.start + j] = batchState.verifyResults[j] != 0;
     470           0 :             }
     471           0 :         }
     472           0 :         doneCallback(result);
     473           0 :     }
     474             : 
     475           0 :     void AsyncAggregate(size_t batchIdx)
     476             :     {
     477           0 :         auto& batchState = batchStates[batchIdx];
     478             : 
     479             :         // aggregate vvecs and skShares of batch in parallel
     480           0 :         auto self(this->shared_from_this());
     481           0 :         auto vvecAgg = std::make_shared<VectorAggregator<CBLSPublicKey>>(vvecs.subspan(batchState.start, batchState.count), parallel, workerPool, [this, self, batchIdx] (const BLSVerificationVectorPtr& vvec) {HandleAggVvecDone(batchIdx, vvec);});
     482           0 :         auto skShareAgg = std::make_shared<Aggregator<CBLSSecretKey>>(Span{skShares}.subspan(batchState.start, batchState.count), parallel, workerPool, [this, self, batchIdx] (const CBLSSecretKey& skShare) {HandleAggSkShareDone(batchIdx, skShare);});
     483             : 
     484           0 :         vvecAgg->Start();
     485           0 :         skShareAgg->Start();
     486           0 :     }
     487             : 
     488           0 :     void HandleAggVvecDone(size_t batchIdx, const BLSVerificationVectorPtr& vvec)
     489             :     {
     490           0 :         auto& batchState = batchStates[batchIdx];
     491           0 :         batchState.vvec = vvec;
     492           0 :         if (++(*batchState.aggDone) == 2) {
     493           0 :             HandleAggDone(batchIdx);
     494           0 :         }
     495           0 :     }
     496           0 :     void HandleAggSkShareDone(size_t batchIdx, const CBLSSecretKey& skShare)
     497             :     {
     498           0 :         auto& batchState = batchStates[batchIdx];
     499           0 :         batchState.skShare = skShare;
     500           0 :         if (++(*batchState.aggDone) == 2) {
     501           0 :             HandleAggDone(batchIdx);
     502           0 :         }
     503           0 :     }
     504             : 
     505           0 :     void HandleVerifyDone(size_t count)
     506             :     {
     507           0 :         size_t c = verifyDoneCount += count;
     508           0 :         if (c == verifyCount) {
     509           0 :             Finish();
     510           0 :         }
     511           0 :     }
     512             : 
     513           0 :     void HandleAggDone(size_t batchIdx)
     514             :     {
     515           0 :         auto& batchState = batchStates[batchIdx];
     516             : 
     517           0 :         if (batchState.vvec == nullptr || batchState.vvec->empty() || !batchState.skShare.IsValid()) {
     518             :             // something went wrong while aggregating and there is nothing we can do now except mark the whole batch as failed
     519             :             // this can only happen if inputs were invalid in some way
     520           0 :             batchState.verifyResults.assign(batchState.count, 0);
     521           0 :             HandleVerifyDone(batchState.count);
     522           0 :             return;
     523             :         }
     524             : 
     525           0 :         AsyncAggregatedVerifyBatch(batchIdx);
     526           0 :     }
     527             : 
     528           0 :     void AsyncAggregatedVerifyBatch(size_t batchIdx)
     529             :     {
     530           0 :         auto self(this->shared_from_this());
     531           0 :         auto f = [this, self, batchIdx](int threadId) {
     532           0 :             auto& batchState = batchStates[batchIdx];
     533           0 :             bool result = Verify(batchState.vvec, batchState.skShare);
     534           0 :             if (result) {
     535             :                 // whole batch is valid
     536           0 :                 batchState.verifyResults.assign(batchState.count, 1);
     537           0 :                 HandleVerifyDone(batchState.count);
     538           0 :             } else {
     539             :                 // at least one entry in the batch is invalid, revert to per-contribution verification (but parallelized)
     540           0 :                 AsyncVerifyBatchOneByOne(batchIdx);
     541             :             }
     542           0 :         };
     543           0 :         PushOrDoWork(std::move(f));
     544           0 :     }
     545             : 
     546           0 :     void AsyncVerifyBatchOneByOne(size_t batchIdx)
     547             :     { // Verifies each entry of batch `batchIdx` separately, creating one work item per entry.
     548           0 :         size_t count = batchStates[batchIdx].count;
     549           0 :         batchStates[batchIdx].verifyResults.assign(count, 0); // start from all-zero; each work item writes its own slot i
     550           0 :         for (size_t i = 0; i < count; i++) {
     551           0 :             auto self(this->shared_from_this()); // captured so every queued item holds a shared_ptr to this object
     552           0 :             auto f = [this, self, i, batchIdx](int threadId) {
     553           0 :                 auto& batchState = batchStates[batchIdx];
     554           0 :                 batchState.verifyResults[i] = Verify(vvecs[batchState.start + i], skShares[batchState.start + i]); // batchState.start offsets into the full input spans
     555           0 :                 HandleVerifyDone(1); // exactly one entry finished
     556           0 :             };
     557           0 :             PushOrDoWork(std::move(f));
     558           0 :         }
     559           0 :     }
     560             : 
     561           0 :     bool Verify(const BLSVerificationVectorPtr& vvec, const CBLSSecretKey& skShare) const
     562             :     { // Returns true iff the public key share computed from *vvec for forId equals skShare's public key.
     563           0 :         CBLSPublicKey pk1;
     564           0 :         if (!pk1.PublicKeyShare(*vvec, forId)) { // vvec is dereferenced without a null check
     565           0 :             return false;
     566             :         }
     567             :         // pk1 was computed successfully; compare it with the key belonging to the secret share
     568           0 :         CBLSPublicKey pk2 = skShare.GetPublicKey();
     569           0 :         return pk1 == pk2;
     570           0 :     }
     571             : 
     572             :     template <typename Callable>
     573           0 :     void PushOrDoWork(Callable&& f)
     574             :     { // Hands f to workerPool when `parallel` is set; otherwise calls it immediately on the current thread.
     575           0 :         if (parallel) {
     576           0 :             workerPool.push(std::forward<Callable>(f));
     577           0 :         } else {
     578           0 :             f(0); // 0 is passed for the lambdas' `int threadId` parameter
     579             :         }
     580           0 :     }
     581             : };
     582             : 
     583           0 : void CBLSWorker::AsyncBuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs, bool parallel,
     584             :                                                     std::function<void(const BLSVerificationVectorPtr&)> doneCallback)
     585             : { // Starts a VectorAggregator<CBLSPublicKey> over vvecs; empty or invalid input is answered with doneCallback(nullptr).
     586           0 :     if (vvecs.empty()) {
     587           0 :         doneCallback(nullptr); // called synchronously on the caller's thread
     588           0 :         return;
     589             :     }
     590           0 :     if (!VerifyVerificationVectors(vvecs)) {
     591           0 :         doneCallback(nullptr); // same nullptr result as for empty input
     592           0 :         return;
     593             :     }
     594             :     // input is non-empty and passed validation; doneCallback is handed over to the aggregator
     595           0 :     auto agg = std::make_shared<VectorAggregator<CBLSPublicKey>>(vvecs, parallel, workerPool, std::move(doneCallback));
     596           0 :     agg->Start(); // NOTE(review): agg is only a local shared_ptr; presumably the aggregator keeps itself alive while working — confirm
     597           0 : }
     598             : 
     599           0 : std::future<BLSVerificationVectorPtr> CBLSWorker::AsyncBuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs, bool parallel)
     600             : { // Future-returning overload: the callback overload is driven with a callback that fulfils the returned future.
     601           0 :     auto p = BuildFutureDoneCallback<BLSVerificationVectorPtr>(); // p.first = done callback, p.second = matching future
     602           0 :     AsyncBuildQuorumVerificationVector(vvecs, parallel, std::move(p.first));
     603           0 :     return std::move(p.second); // pair member, not a plain local: explicit move needed because std::future is move-only
     604           0 : }
     605             : 
     606           0 : BLSVerificationVectorPtr CBLSWorker::BuildQuorumVerificationVector(Span<BLSVerificationVectorPtr> vvecs, bool parallel)
     607             : { // Blocking variant: waits for the result of the future-returning overload.
     608           0 :     return AsyncBuildQuorumVerificationVector(vvecs, parallel).get();
     609           0 : }
     610             : 
     611             : template <typename T>
     612           0 : void AsyncAggregateHelper(ctpl::thread_pool& workerPool, Span<T> vec, bool parallel,
     613             :                           std::function<void(const T&)> doneCallback)
     614             : { // Starts an Aggregator<T> over vec; empty or invalid input is answered with a default-constructed T.
     615           0 :     if (vec.empty()) {
     616           0 :         doneCallback(T()); // called synchronously on the caller's thread
     617           0 :         return;
     618             :     }
     619           0 :     if (!VerifyVectorHelper(vec)) { // rejects vectors with an invalid item or with two items sharing a hash
     620           0 :         doneCallback(T());
     621           0 :         return;
     622             :     }
     623             :     // input is non-empty and passed validation; doneCallback is handed over to the aggregator
     624           0 :     auto agg = std::make_shared<Aggregator<T>>(vec, parallel, workerPool, std::move(doneCallback));
     625           0 :     agg->Start();
     626           0 : }
     627             : 
     628           0 : void CBLSWorker::AsyncAggregateSecretKeys(Span<CBLSSecretKey> secKeys, bool parallel,
     629             :                                           std::function<void(const CBLSSecretKey&)> doneCallback)
     630             : { // Forwards to AsyncAggregateHelper using this worker's thread pool.
     631           0 :     AsyncAggregateHelper(workerPool, secKeys, parallel, std::move(doneCallback));
     632           0 : }
     633             : 
     634           0 : std::future<CBLSSecretKey> CBLSWorker::AsyncAggregateSecretKeys(Span<CBLSSecretKey> secKeys, bool parallel)
     635             : { // Future-returning overload built on the callback overload.
     636           0 :     auto p = BuildFutureDoneCallback<CBLSSecretKey>(); // p.first = done callback, p.second = matching future
     637           0 :     AsyncAggregateSecretKeys(secKeys, parallel, std::move(p.first));
     638           0 :     return std::move(p.second); // std::future is move-only and p.second is a pair member
     639           0 : }
     640             : 
     641           0 : CBLSSecretKey CBLSWorker::AggregateSecretKeys(Span<CBLSSecretKey> secKeys, bool parallel)
     642             : { // Blocking variant: waits for the result of the future-returning overload.
     643           0 :     return AsyncAggregateSecretKeys(secKeys, parallel).get();
     644           0 : }
     645             : 
     646           0 : void CBLSWorker::AsyncAggregatePublicKeys(Span<CBLSPublicKey> pubKeys, bool parallel,
     647             :                                           std::function<void(const CBLSPublicKey&)> doneCallback)
     648             : { // Forwards to AsyncAggregateHelper using this worker's thread pool.
     649           0 :     AsyncAggregateHelper(workerPool, pubKeys, parallel, std::move(doneCallback));
     650           0 : }
     651             : 
     652           0 : std::future<CBLSPublicKey> CBLSWorker::AsyncAggregatePublicKeys(Span<CBLSPublicKey> pubKeys, bool parallel)
     653             : { // Future-returning overload built on the callback overload.
     654           0 :     auto p = BuildFutureDoneCallback<CBLSPublicKey>(); // p.first = done callback, p.second = matching future
     655           0 :     AsyncAggregatePublicKeys(pubKeys, parallel, std::move(p.first));
     656           0 :     return std::move(p.second); // std::future is move-only and p.second is a pair member
     657           0 : }
     658             : 
     659           0 : void CBLSWorker::AsyncAggregateSigs(Span<CBLSSignature> sigs, bool parallel,
     660             :                                     std::function<void(const CBLSSignature&)> doneCallback)
     661             : { // Forwards to AsyncAggregateHelper using this worker's thread pool.
     662           0 :     AsyncAggregateHelper(workerPool, sigs, parallel, std::move(doneCallback));
     663           0 : }
     664             : 
     665           0 : std::future<CBLSSignature> CBLSWorker::AsyncAggregateSigs(Span<CBLSSignature> sigs, bool parallel)
     666             : { // Future-returning overload built on the callback overload.
     667           0 :     auto p = BuildFutureDoneCallback<CBLSSignature>(); // p.first = done callback, p.second = matching future
     668           0 :     AsyncAggregateSigs(sigs, parallel, std::move(p.first));
     669           0 :     return std::move(p.second); // std::future is move-only and p.second is a pair member
     670           0 : }
     671             : 
     672           0 : CBLSPublicKey CBLSWorker::BuildPubKeyShare(const BLSVerificationVectorPtr& vvec, const CBLSId& id)
     673             : { // Returns the key filled in by CBLSPublicKey::PublicKeyShare(*vvec, id).
     674           0 :     CBLSPublicKey pkShare;
     675           0 :     pkShare.PublicKeyShare(*vvec, id); // NOTE(review): result is ignored and vvec is not null-checked, unlike the `if (!pk1.PublicKeyShare(...))` checks elsewhere in this file — confirm callers rely on that
     676           0 :     return pkShare;
     677             : }
     678             : 
     679           0 : void CBLSWorker::AsyncVerifyContributionShares(const CBLSId& forId, Span<BLSVerificationVectorPtr> vvecs, Span<CBLSSecretKey> skShares,
     680             :                                                bool parallel, bool aggregated, std::function<void(const std::vector<bool>&)> doneCallback)
     681             : { // Starts a ContributionVerifier over vvecs/skShares for forId; doneCallback is handed over to that verifier.
     682           0 :     if (!forId.IsValid() || !VerifyVerificationVectors(vvecs)) {
     683           0 :         std::vector<bool> result; // invalid id or invalid vectors: every entry is reported as false
     684           0 :         result.assign(vvecs.size(), false);
     685           0 :         doneCallback(result); // called synchronously on the caller's thread
     686             :         return;
     687           0 :     }
     688             :     // NOTE(review): skShares.size() is not compared with vvecs.size() here — confirm ContributionVerifier or the callers ensure it
     689           0 :     auto verifier = std::make_shared<ContributionVerifier>(forId, vvecs, skShares, 8, parallel, aggregated, workerPool, std::move(doneCallback)); // 8: presumably the batch size — TODO confirm against the constructor
     690           0 :     verifier->Start();
     691           0 : }
     692             : 
     693           0 : std::future<std::vector<bool> > CBLSWorker::AsyncVerifyContributionShares(const CBLSId& forId, Span<BLSVerificationVectorPtr> vvecs, Span<CBLSSecretKey> skShares,
     694             :                                                                           bool parallel, bool aggregated)
     695             : { // Future-returning overload built on the callback overload.
     696           0 :     auto p = BuildFutureDoneCallback<std::vector<bool> >(); // p.first = done callback, p.second = matching future
     697           0 :     AsyncVerifyContributionShares(forId, vvecs, skShares, parallel, aggregated, std::move(p.first));
     698           0 :     return std::move(p.second); // std::future is move-only and p.second is a pair member
     699           0 : }
     700             : 
     701           0 : std::vector<bool> CBLSWorker::VerifyContributionShares(const CBLSId& forId, Span<BLSVerificationVectorPtr> vvecs, Span<CBLSSecretKey> skShares,
     702             :                                                        bool parallel, bool aggregated)
     703             : { // Blocking variant: waits for the result of the future-returning overload.
     704           0 :     return AsyncVerifyContributionShares(forId, vvecs, skShares, parallel, aggregated).get();
     705           0 : }
     706             : 
     707           0 : std::future<bool> CBLSWorker::AsyncVerifyContributionShare(const CBLSId& forId,
     708             :                                                            const BLSVerificationVectorPtr& vvec,
     709             :                                                            const CBLSSecretKey& skContribution)
     710             : { // Checks on the worker pool that the key share computed from *vvec for forId equals skContribution's public key.
     711           0 :     if (!forId.IsValid() || vvec == nullptr || !VerifyVerificationVector(*vvec)) { // a null vvec is reported as invalid instead of being dereferenced
     712           0 :         auto p = BuildFutureDoneCallback<bool>(); // already-completed future carrying `false`
     713           0 :         p.first(false);
     714           0 :         return std::move(p.second); // std::future is move-only and p.second is a pair member
     715           0 :     }
     716             :     // The job may run after this call has returned, so it captures copies rather than references to the caller's objects.
     717           0 :     auto f = [forId, vvec, skContribution](int threadId) {
     718           0 :         CBLSPublicKey pk1;
     719           0 :         if (!pk1.PublicKeyShare(*vvec, forId)) {
     720           0 :             return false;
     721             :         }
     722             : 
     723           0 :         CBLSPublicKey pk2 = skContribution.GetPublicKey();
     724           0 :         return pk1 == pk2;
     725           0 :     };
     726           0 :     return workerPool.push(f); // the pool returns the future for the job's bool result
     727           0 : }
     728             : 
     729           0 : bool CBLSWorker::VerifyVerificationVector(Span<CBLSPublicKey> vvec)
     730             : { // True iff every key IsValid() and no two keys have the same GetHash() (checks done by VerifyVectorHelper).
     731           0 :     return VerifyVectorHelper(vvec);
     732             : }
     733             : 
     734           0 : bool CBLSWorker::VerifyVerificationVectors(Span<BLSVerificationVectorPtr> vvecs)
     735             : { // True iff all vectors are non-null and equally sized, and every key is valid and unique across ALL vectors.
     736           0 :     std::set<uint256> set; // shared by all vectors, so a key repeated in two different vectors is rejected as well
     737           0 :     for (const auto& vvec : vvecs) {
     738           0 :         if (vvec == nullptr) {
     739           0 :             return false;
     740             :         }
     741           0 :         if (vvec->size() != vvecs[0]->size()) { // vvecs[0] is non-null here: the first iteration already passed the nullptr check
     742           0 :             return false;
     743             :         }
     744           0 :         for (size_t j = 0; j < vvec->size(); j++) {
     745           0 :             if (!(*vvec)[j].IsValid()) {
     746           0 :                 return false;
     747             :             }
     748             :             // check duplicates
     749           0 :             if (!set.emplace((*vvec)[j].GetHash()).second) { // .second is false when the hash was already in the set
     750           0 :                 return false;
     751             :             }
     752           0 :         }
     753             :     }
     754             :     // an empty vvecs never enters the loop and is therefore accepted
     755           0 :     return true;
     756           0 : }
     757             : 
     758           0 : void CBLSWorker::AsyncSign(const CBLSSecretKey& secKey, const uint256& msgHash, const CBLSWorker::SignDoneCallback& doneCallback)
     759             : { // Signs msgHash with secKey on the worker pool and passes the signature to doneCallback.
     760           0 :     workerPool.push([secKey, msgHash, doneCallback](int threadId) { // copies are captured, so the job does not refer to the caller's arguments
     761           0 :         doneCallback(secKey.Sign(msgHash, bls::bls_legacy_scheme.load())); // the scheme flag is read when the job runs, not when it is queued
     762           0 :     });
     763           0 : }
     764             : 
     765           0 : void CBLSWorker::AsyncVerifySig(const CBLSSignature& sig, const CBLSPublicKey& pubKey, const uint256& msgHash,
     766             :                                 CBLSWorker::SigVerifyDoneCallback doneCallback, CancelCond cancelCond)
     767             : { // Queues one signature check for batched verification; an invalid sig or pubKey is answered with false right away.
     768           0 :     if (!sig.IsValid() || !pubKey.IsValid()) {
     769           0 :         doneCallback(false); // called synchronously, before the mutex is taken
     770           0 :         return;
     771             :     }
     772             :     // from here on sigVerifyQueue and sigVerifyBatchesInProgress are only touched under sigVerifyMutex
     773           0 :     std::unique_lock<std::mutex> l(sigVerifyMutex);
     774             : 
     775           0 :     bool foundDuplicate = std::ranges::any_of(sigVerifyQueue,
     776           0 :                                               [&msgHash](const auto& job) { return job.msgHash == msgHash; });
     777             : 
     778           0 :     if (foundDuplicate) {
     779             :         // batched/aggregated verification does not allow duplicate hashes, so we push what we currently have and start
     780             :         // with a fresh batch
     781           0 :         PushSigVerifyBatch();
     782           0 :     }
     783             :     // the new job is always queued; sig, pubKey and msgHash are copied into it
     784           0 :     sigVerifyQueue.emplace_back(std::move(doneCallback), std::move(cancelCond), sig, pubKey, msgHash);
     785           0 :     if (sigVerifyBatchesInProgress == 0 || sigVerifyQueue.size() >= SIG_VERIFY_BATCH_SIZE) { // flush when no batch is running or the queue reached the batch size
     786           0 :         PushSigVerifyBatch();
     787           0 :     }
     788           0 : }
     789             : 
     790           0 : std::future<bool> CBLSWorker::AsyncVerifySig(const CBLSSignature& sig, const CBLSPublicKey& pubKey, const uint256& msgHash, CancelCond cancelCond)
     791             : { // Future-returning overload built on the callback overload.
     792           0 :     auto p = BuildFutureDoneCallback2<bool>(); // defined outside this view; presumably the by-value-callback twin of BuildFutureDoneCallback — confirm
     793           0 :     AsyncVerifySig(sig, pubKey, msgHash, std::move(p.first), std::move(cancelCond));
     794           0 :     return std::move(p.second); // std::future is move-only and p.second is a pair member
     795           0 : }
     796             : 
     797           0 : bool CBLSWorker::IsAsyncVerifyInProgress()
     798             : { // True while at least one batch pushed by PushSigVerifyBatch has not yet decremented the counter.
     799           0 :     std::unique_lock<std::mutex> l(sigVerifyMutex); // the counter is read under the same mutex that guards its updates
     800           0 :     return sigVerifyBatchesInProgress != 0;
     801           0 : }
     802             : 
     803             : // sigVerifyMutex must be held while calling
     804           0 : void CBLSWorker::PushSigVerifyBatch()
     805             : { // Moves the whole sigVerifyQueue into one pool job; when that job finishes it pushes the next batch if the queue refilled.
     806           0 :     auto f = [this](int threadId, const std::shared_ptr<std::vector<SigVerifyJob> >& _jobs) {
     807           0 :         auto& jobs = *_jobs;
     808           0 :         if (jobs.size() == 1) { // single job: nothing to aggregate, verify it directly
     809           0 :             const auto& job = jobs[0];
     810           0 :             if (!job.cancelCond()) { // a cancelled job is dropped and its doneCallback is never called
     811           0 :                 bool valid = job.sig.VerifyInsecure(job.pubKey, job.msgHash);
     812           0 :                 job.doneCallback(valid);
     813           0 :             }
     814           0 :             std::unique_lock<std::mutex> l(sigVerifyMutex); // taken only after the callback ran; required by PushSigVerifyBatch below
     815           0 :             sigVerifyBatchesInProgress--;
     816           0 :             if (!sigVerifyQueue.empty()) { // jobs were queued while this batch was running
     817           0 :                 PushSigVerifyBatch();
     818           0 :             }
     819             :             return;
     820           0 :         }
     821             :         // several jobs: aggregate the signatures of all non-cancelled jobs and verify them in one call
     822           0 :         CBLSSignature aggSig;
     823           0 :         std::vector<size_t> indexes; // indexes[k] = position in `jobs` of the k-th non-cancelled job
     824           0 :         std::vector<CBLSPublicKey> pubKeys;
     825           0 :         std::vector<uint256> msgHashes;
     826           0 :         indexes.reserve(jobs.size());
     827           0 :         pubKeys.reserve(jobs.size());
     828           0 :         msgHashes.reserve(jobs.size());
     829           0 :         for (size_t i = 0; i < jobs.size(); i++) {
     830           0 :             auto& job = jobs[i];
     831           0 :             if (job.cancelCond()) { // cancelled jobs take no part in the aggregate and get no callback
     832           0 :                 continue;
     833             :             }
     834           0 :             if (pubKeys.empty()) { // first surviving job seeds the aggregate signature
     835           0 :                 aggSig = job.sig;
     836           0 :             } else {
     837           0 :                 aggSig.AggregateInsecure(job.sig);
     838             :             }
     839           0 :             indexes.emplace_back(i);
     840           0 :             pubKeys.emplace_back(job.pubKey);
     841           0 :             msgHashes.emplace_back(job.msgHash);
     842           0 :         }
     843             :         // pubKeys, msgHashes and indexes now have one entry per non-cancelled job, in the same order
     844           0 :         if (!pubKeys.empty()) { // skipped entirely when every job was cancelled
     845           0 :             bool allValid = aggSig.VerifyInsecureAggregated(pubKeys, msgHashes);
     846           0 :             if (allValid) {
     847           0 :                 for (size_t i = 0; i < pubKeys.size(); i++) {
     848           0 :                     jobs[indexes[i]].doneCallback(true);
     849           0 :                 }
     850           0 :             } else {
     851             :                 // one or more sigs were not valid, revert to per-sig verification
     852             :                 // TODO this could be improved if we would cache pairing results in some way as the previous aggregated verification already calculated all the pairings for the hashes
     853           0 :                 for (size_t i = 0; i < pubKeys.size(); i++) {
     854           0 :                     const auto& job = jobs[indexes[i]];
     855           0 :                     bool valid = job.sig.VerifyInsecure(job.pubKey, job.msgHash);
     856           0 :                     job.doneCallback(valid);
     857           0 :                 }
     858             :             }
     859           0 :         }
     860             :         // all callbacks have run without the mutex; take it now to update the shared counter and queue
     861           0 :         std::unique_lock<std::mutex> l(sigVerifyMutex);
     862           0 :         sigVerifyBatchesInProgress--;
     863           0 :         if (!sigVerifyQueue.empty()) { // jobs were queued while this batch was running
     864           0 :             PushSigVerifyBatch();
     865           0 :         }
     866           0 :     };
     867             :     // everything below runs on the caller's thread, under the mutex the caller holds
     868           0 :     auto batch = std::make_shared<std::vector<SigVerifyJob> >(std::move(sigVerifyQueue)); // the batch takes over all queued jobs
     869           0 :     sigVerifyQueue.reserve(SIG_VERIFY_BATCH_SIZE); // the old storage moved into `batch`, so reserve again for the next one
     870             : 
     871           0 :     sigVerifyBatchesInProgress++; // decremented by the job itself once it is done
     872           0 :     workerPool.push(f, batch); // `batch` arrives in the job as its `_jobs` argument
     873           0 : }

Generated by: LCOV version 1.16