LCOV - code coverage report
Current view: top level - src/coinjoin - util.cpp (source / functions) Hit Total Coverage
Test: test_dash_coverage.info Lines: 129 199 64.8 %
Date: 2026-06-25 07:23:51 Functions: 23 32 71.9 %

          Line data    Source code
       1             : // Copyright (c) 2014-2025 The Dash Core developers
       2             : // Distributed under the MIT/X11 software license, see the accompanying
       3             : // file COPYING or http://www.opensource.org/licenses/mit-license.php.
       4             : 
       5             : #include <coinjoin/util.h>
       6             : #include <policy/fees.h>
       7             : #include <policy/policy.h>
       8             : #include <util/translation.h>
       9             : #include <wallet/fees.h>
      10             : #include <wallet/spend.h>
      11             : #include <wallet/wallet.h>
      12             : #include <wallet/walletutil.h>
      13             : 
      14             : #include <numeric>
      15             : 
      16             : using wallet::CompactTallyItem;
      17             : using wallet::CRecipient;
      18             : using wallet::CWallet;
      19             : using wallet::FEATURE_COMPRPUBKEY;
      20             : using wallet::GetDiscardRate;
      21             : using wallet::RANDOM_CHANGE_POSITION;
      22             : using wallet::WalletBatch;
      23             : 
      24         228 : inline unsigned int GetSizeOfCompactSizeDiff(uint64_t nSizePrev, uint64_t nSizeNew)
      25             : {
      26         228 :     assert(nSizePrev <= nSizeNew);
      27         228 :     return ::GetSizeOfCompactSize(nSizeNew) - ::GetSizeOfCompactSize(nSizePrev);
      28             : }
      29             : 
      30           0 : CKeyHolder::CKeyHolder(CWallet* pwallet) :
      31           0 :     reserveDestination(pwallet)
      32           0 : {
      33           0 :     auto dest_opt = reserveDestination.GetReservedDestination(false);
      34           0 :     assert(dest_opt);
      35           0 :     dest = *dest_opt;
      36           0 : }
      37             : 
      38           0 : void CKeyHolder::KeepKey()
      39             : {
      40           0 :     reserveDestination.KeepDestination();
      41           0 : }
      42             : 
      43           0 : void CKeyHolder::ReturnKey()
      44             : {
      45           0 :     reserveDestination.ReturnDestination();
      46           0 : }
      47             : 
      48           0 : CScript CKeyHolder::GetScriptForDestination() const
      49             : {
      50           0 :     return ::GetScriptForDestination(dest);
      51             : }
      52             : 
      53             : 
      54           0 : CScript CKeyHolderStorage::AddKey(CWallet* pwallet)
      55             : {
      56           0 :     auto keyHolderPtr = std::make_unique<CKeyHolder>(pwallet);
      57           0 :     auto script = keyHolderPtr->GetScriptForDestination();
      58             : 
      59           0 :     LOCK(cs_storage);
      60           0 :     storage.emplace_back(std::move(keyHolderPtr));
      61           0 :     LogPrint(BCLog::COINJOIN, "CKeyHolderStorage::%s -- storage size %lld\n", __func__, storage.size());
      62           0 :     return script;
      63           0 : }
      64             : 
      65           0 : void CKeyHolderStorage::KeepAll()
      66             : {
      67           0 :     std::vector<std::unique_ptr<CKeyHolder> > tmp;
      68             :     {
      69             :         // don't hold cs_storage while calling KeepKey(), which might lock cs_wallet
      70           0 :         LOCK(cs_storage);
      71           0 :         std::swap(storage, tmp);
      72           0 :     }
      73             : 
      74           0 :     if (!tmp.empty()) {
      75           0 :         for (const auto& key : tmp) {
      76           0 :             key->KeepKey();
      77             :         }
      78           0 :         LogPrint(BCLog::COINJOIN, "CKeyHolderStorage::%s -- %lld keys kept\n", __func__, tmp.size());
      79           0 :     }
      80           0 : }
      81             : 
      82           0 : void CKeyHolderStorage::ReturnAll()
      83             : {
      84           0 :     std::vector<std::unique_ptr<CKeyHolder> > tmp;
      85             :     {
      86             :         // don't hold cs_storage while calling ReturnKey(), which might lock cs_wallet
      87           0 :         LOCK(cs_storage);
      88           0 :         std::swap(storage, tmp);
      89           0 :     }
      90             : 
      91           0 :     if (!tmp.empty()) {
      92           0 :         for (const auto& key : tmp) {
      93           0 :             key->ReturnKey();
      94             :         }
      95           0 :         LogPrint(BCLog::COINJOIN, "CKeyHolderStorage::%s -- %lld keys returned\n", __func__, tmp.size());
      96           0 :     }
      97           0 : }
      98             : 
      99         303 : CTransactionBuilderOutput::CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, CWallet& wallet, CAmount nAmountIn) :
     100         101 :     pTxBuilder(pTxBuilderIn),
     101         101 :     dest(&wallet),
     102         101 :     nAmount(nAmountIn)
     103         101 : {
     104         101 :     assert(pTxBuilder);
     105         101 :     LOCK(wallet.cs_wallet);
     106         101 :     auto dest_opt = dest.GetReservedDestination(false);
     107         101 :     assert(dest_opt);
     108         101 :     script = ::GetScriptForDestination(*dest_opt);
     109         202 : }
     110             : 
     111           7 : bool CTransactionBuilderOutput::UpdateAmount(const CAmount nNewAmount)
     112             : {
     113           7 :     if (nNewAmount <= 0 || nNewAmount - nAmount > pTxBuilder->GetAmountLeft()) {
     114           3 :         return false;
     115             :     }
     116           4 :     nAmount = nNewAmount;
     117           4 :     return true;
     118           7 : }
     119             : 
     120           6 : CTransactionBuilder::CTransactionBuilder(CWallet& wallet, const CompactTallyItem& tallyItemIn) :
     121           2 :     m_wallet(wallet),
     122           2 :     dummyReserveDestination(&wallet),
     123           2 :     tallyItem(tallyItemIn)
     124           2 : {
     125             :     // Generate a feerate which will be used to consider if the remainder is dust and will go into fees or not
     126             :     coinControl.m_discard_feerate = ::GetDiscardRate(m_wallet);
     127             :     // Generate a feerate which will be used by calculations of this class and also by CWallet::CreateTransaction
     128             :     coinControl.m_feerate = std::max(GetRequiredFeeRate(m_wallet), m_wallet.m_pay_tx_fee);
     129             :     // If wallet does not have the avoid-reuse feature enabled, keep legacy
     130             :     // behavior: force change to go back to the origin address. When
     131             :     // WALLET_FLAG_AVOID_REUSE is enabled, let the wallet select a fresh
     132             :     // change destination to avoid address reuse.
     133             :     if (!m_wallet.IsWalletFlagSet(wallet::WALLET_FLAG_AVOID_REUSE)) {
     134             :         coinControl.destChange = tallyItemIn.txdest;
     135             :     }
     136             :     // Only allow tallyItems inputs for tx creation
     137             :     coinControl.m_allow_other_inputs = false;
     138             :     // Create dummy tx to calculate the exact required fees upfront for accurate amount and fee calculations
     139             :     CMutableTransaction dummyTx;
     140             :     // Select all tallyItem outputs in the coinControl so that CreateTransaction knows what to use
     141             :     for (const auto& outpoint : tallyItem.outpoints) {
     142             :         coinControl.Select(outpoint);
     143             :         dummyTx.vin.emplace_back(outpoint, CScript());
     144             :     }
     145             :     // Get a comparable dummy scriptPubKey, avoid writing/flushing to the actual wallet db
     146             :     CScript dummyScript;
     147             :     {
     148             :         LOCK(m_wallet.cs_wallet);
     149             :         WalletBatch dummyBatch(m_wallet.GetDatabase(), false);
     150             :         dummyBatch.TxnBegin();
     151             :         CKey secret;
     152             :         secret.MakeNewKey(m_wallet.CanSupportFeature(FEATURE_COMPRPUBKEY));
     153             :         CPubKey dummyPubkey = secret.GetPubKey();
     154             :         dummyBatch.TxnAbort();
     155             :         dummyScript = ::GetScriptForDestination(PKHash(dummyPubkey));
     156             :         // Calculate required bytes for the dummy signed tx with tallyItem's inputs only
     157             :         nBytesBase = CalculateMaximumSignedTxSize(CTransaction(dummyTx), &m_wallet, /*coin_control=*/nullptr);
     158             :     }
     159             :     // Calculate the output size
     160             :     nBytesOutput = ::GetSerializeSize(CTxOut(0, dummyScript), PROTOCOL_VERSION);
     161             :     // Just to make sure..
     162             :     Clear();
     163           2 : }
     164             : 
     165           4 : CTransactionBuilder::~CTransactionBuilder()
     166           2 : {
     167           2 :     Clear();
     168           4 : }
     169             : 
     170           4 : void CTransactionBuilder::Clear()
     171             : {
     172           4 :     std::vector<std::unique_ptr<CTransactionBuilderOutput>> vecOutputsTmp;
     173             :     {
     174             :         // Don't hold cs_outputs while clearing the outputs which might indirectly call lock cs_wallet
     175           4 :         LOCK(cs_outputs);
     176           4 :         std::swap(vecOutputs, vecOutputsTmp);
     177           4 :         vecOutputs.clear();
     178           4 :     }
     179             : 
     180         105 :     for (auto& key : vecOutputsTmp) {
     181         101 :         if (fKeepKeys) {
     182         101 :             key->KeepKey();
     183         101 :         } else {
     184           0 :             key->ReturnKey();
     185             :         }
     186             :     }
     187             :     // Always return this key just to make sure..
     188           4 :     dummyReserveDestination.ReturnDestination();
     189           4 : }
     190             : 
     191         107 : bool CTransactionBuilder::CouldAddOutput(CAmount nAmountOutput) const
     192             : {
     193         107 :     if (nAmountOutput < 0) {
     194           2 :         return false;
     195             :     }
     196             :     // Adding another output can change the serialized size of the vout size hence + GetSizeOfCompactSizeDiff()
     197         105 :     unsigned int nBytes = GetBytesTotal() + nBytesOutput + GetSizeOfCompactSizeDiff(1);
     198         105 :     return GetAmountLeft(GetAmountInitial(), GetAmountUsed() + nAmountOutput, GetFee(nBytes)) >= 0;
     199         107 : }
     200             : 
     201           2 : bool CTransactionBuilder::CouldAddOutputs(const std::vector<CAmount>& vecOutputAmounts) const
     202             : {
     203           2 :     CAmount nAmountAdditional{0};
     204           2 :     assert(vecOutputAmounts.size() < std::numeric_limits<int>::max());
     205           2 :     int nBytesAdditional = nBytesOutput * int(vecOutputAmounts.size());
     206           8 :     for (const auto nAmountOutput : vecOutputAmounts) {
     207           6 :         if (nAmountOutput < 0) {
     208           0 :             return false;
     209             :         }
     210           6 :         nAmountAdditional += nAmountOutput;
     211             :     }
     212             :     // Adding other outputs can change the serialized size of the vout size hence + GetSizeOfCompactSizeDiff()
     213           2 :     unsigned int nBytes = GetBytesTotal() + nBytesAdditional + GetSizeOfCompactSizeDiff(vecOutputAmounts.size());
     214           2 :     return GetAmountLeft(GetAmountInitial(), GetAmountUsed() + nAmountAdditional, GetFee(nBytes)) >= 0;
     215           2 : }
     216             : 
     217         103 : CTransactionBuilderOutput* CTransactionBuilder::AddOutput(CAmount nAmountOutput)
     218             : {
     219         103 :     if (CouldAddOutput(nAmountOutput)) {
     220         101 :         LOCK(cs_outputs);
     221         101 :         vecOutputs.push_back(std::make_unique<CTransactionBuilderOutput>(this, m_wallet, nAmountOutput));
     222         101 :         return vecOutputs.back().get();
     223         101 :     }
     224           2 :     return nullptr;
     225         103 : }
     226             : 
     227         120 : unsigned int CTransactionBuilder::GetBytesTotal() const
     228             : {
     229         120 :     LOCK(cs_outputs);
     230             :     // Adding other outputs can change the serialized size of the vout size hence + GetSizeOfCompactSizeDiff()
     231         120 :     return nBytesBase + vecOutputs.size() * nBytesOutput + ::GetSizeOfCompactSizeDiff(0, vecOutputs.size());
     232         120 : }
     233             : 
     234         107 : CAmount CTransactionBuilder::GetAmountLeft(const CAmount nAmountInitial, const CAmount nAmountUsed, const CAmount nFee)
     235             : {
     236         107 :     return nAmountInitial - nAmountUsed - nFee;
     237             : }
     238             : 
     239         118 : CAmount CTransactionBuilder::GetAmountUsed() const
     240             : {
     241         118 :     LOCK(cs_outputs);
     242        5276 :     return std::accumulate(vecOutputs.begin(), vecOutputs.end(), CAmount{0}, [](const CAmount& a, const auto& b){
     243        5158 :         return a + b->GetAmount();
     244             :     });
     245         118 : }
     246             : 
     247         120 : CAmount CTransactionBuilder::GetFee(unsigned int nBytes) const
     248             : {
     249         120 :     CAmount nFeeCalc = coinControl.m_feerate->GetFee(nBytes);
     250         120 :     CAmount nRequiredFee = GetRequiredFee(m_wallet, nBytes);
     251         120 :     if (nRequiredFee > nFeeCalc) {
     252           0 :         nFeeCalc = nRequiredFee;
     253           0 :     }
     254         120 :     if (nFeeCalc > m_wallet.m_default_max_tx_fee) {
     255           0 :         nFeeCalc = m_wallet.m_default_max_tx_fee;
     256           0 :     }
     257         120 :     return nFeeCalc;
     258             : }
     259             : 
     260         108 : int CTransactionBuilder::GetSizeOfCompactSizeDiff(size_t nAdd) const
     261             : {
     262         216 :     size_t nSize = WITH_LOCK(cs_outputs, return vecOutputs.size());
     263         108 :     unsigned int ret = ::GetSizeOfCompactSizeDiff(nSize, nSize + nAdd);
     264         108 :     assert(ret <= std::numeric_limits<int>::max());
     265         108 :     return int(ret);
     266             : }
     267             : 
     268           2 : bool CTransactionBuilder::IsDust(CAmount nAmount) const
     269             : {
     270           2 :     return ::IsDust(CTxOut(nAmount, ::GetScriptForDestination(tallyItem.txdest)), coinControl.m_discard_feerate.value());
     271           0 : }
     272             : 
     273           3 : bool CTransactionBuilder::Commit(bilingual_str& strResult)
     274             : {
     275           3 :     CAmount nFeeRet = 0;
     276           3 :     int nChangePosRet{RANDOM_CHANGE_POSITION};
     277             : 
     278             :     // Transform the outputs to the format CWallet::CreateTransaction requires
     279           3 :     std::vector<CRecipient> vecSend;
     280             :     {
     281           3 :         LOCK(cs_outputs);
     282           3 :         vecSend.reserve(vecOutputs.size());
     283         105 :         for (const auto& out : vecOutputs) {
     284         102 :             vecSend.push_back((CRecipient){out->GetScript(), out->GetAmount(), false});
     285             :         }
     286           3 :     }
     287             : 
     288           3 :     CTransactionRef tx;
     289             :     {
     290           3 :         LOCK2(m_wallet.cs_wallet, ::cs_main);
     291           3 :         auto ret = wallet::CreateTransaction(m_wallet, vecSend, nChangePosRet, coinControl);
     292           3 :         if (ret) {
     293           2 :             tx = ret->tx;
     294           2 :             nFeeRet = ret->fee;
     295           2 :             nChangePosRet = ret->change_pos;
     296           2 :         } else {
     297           1 :             strResult = util::ErrorString(ret);
     298           1 :             return false;
     299             :         }
     300           3 :     }
     301             : 
     302           2 :     CAmount nAmountLeft = GetAmountLeft();
     303           2 :     bool fDust = IsDust(nAmountLeft);
     304             :     // If there is a either remainder which is considered to be dust (will be added to fee in this case) or no amount left there should be no change output, return if there is a change output.
     305           2 :     if (nChangePosRet != RANDOM_CHANGE_POSITION && fDust) {
     306           0 :         strResult = Untranslated(strprintf("Unexpected change output %s at position %d", tx->vout[nChangePosRet].ToString(), nChangePosRet));
     307           0 :         return false;
     308             :     }
     309             : 
     310             :     // If there is a remainder which is not considered to be dust it should end up in a change output, return if not.
     311           2 :     if (nChangePosRet == RANDOM_CHANGE_POSITION && !fDust) {
     312           0 :         strResult = Untranslated(strprintf("Change output missing: %d", nAmountLeft));
     313           0 :         return false;
     314             :     }
     315             : 
     316           2 :     CAmount nFeeAdditional{0};
     317           2 :     unsigned int nBytesAdditional{0};
     318             : 
     319           2 :     if (fDust) {
     320           1 :         nFeeAdditional = nAmountLeft;
     321           1 :     } else {
     322             :         // Add a change output and GetSizeOfCompactSizeDiff(1) as another output can changes the serialized size of the vout size in CTransaction
     323           1 :         nBytesAdditional = nBytesOutput + GetSizeOfCompactSizeDiff(1);
     324             :     }
     325             : 
     326             :     // If the calculated fee does not match the fee returned by CreateTransaction aka if this check fails something is wrong!
     327           2 :     CAmount nFeeCalc = GetFee(GetBytesTotal() + nBytesAdditional) + nFeeAdditional;
     328           2 :     if (nFeeRet != nFeeCalc) {
     329           0 :         strResult = Untranslated(strprintf("Fee validation failed -> nFeeRet: %d, nFeeCalc: %d, nFeeAdditional: %d, nBytesAdditional: %d, %s", nFeeRet, nFeeCalc, nFeeAdditional, nBytesAdditional, ToString()));
     330           0 :         return false;
     331             :     }
     332             : 
     333             :     {
     334           2 :         LOCK2(m_wallet.cs_wallet, ::cs_main);
     335           2 :         m_wallet.CommitTransaction(tx, {}, {});
     336           2 :     }
     337             : 
     338           2 :     fKeepKeys = true;
     339             : 
     340           2 :     strResult = Untranslated(tx->GetHash().ToString());
     341             : 
     342           2 :     return true;
     343           3 : }
     344             : 
     345           0 : std::string CTransactionBuilder::ToString() const
     346             : {
     347           0 :     return strprintf("CTransactionBuilder(Amount initial: %d, Amount left: %d, Bytes base: %d, Bytes output: %d, Bytes total: %d, Amount used: %d, Outputs: %d, Fee rate: %d, Discard fee rate: %d, Fee: %d)",
     348           0 :         GetAmountInitial(),
     349           0 :         GetAmountLeft(),
     350           0 :         nBytesBase,
     351           0 :         nBytesOutput,
     352           0 :         GetBytesTotal(),
     353           0 :         GetAmountUsed(),
     354           0 :         CountOutputs(),
     355           0 :         coinControl.m_feerate->GetFeePerK(),
     356           0 :         coinControl.m_discard_feerate->GetFeePerK(),
     357           0 :         GetFee(GetBytesTotal()));
     358             : }

Generated by: LCOV version 1.16