{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}

{- | Threat model for detecting Self-Reference Injection vulnerabilities.

A Self-Reference Injection attack exploits validators that check "pay to address X"
where X comes from the datum. If an attacker can set X to the script's own address,
then the continuation output (which necessarily goes to the script) satisfies the
payment check for free.

== Consequences ==

1. __Free value extraction__: The attacker bypasses payment requirements.
   Instead of paying to the intended recipient, the "payment" is just the
   continuation output that was going to the script anyway.

2. __Protocol violation__: The semantic meaning of "pay to X" is violated.
   For example, in a "king of the hill" contract, the old king never receives
   their rightful payment when overthrown.

== Vulnerable Patterns ==

@
validator spend(datum: Datum, _redeemer: Void, ctx: ScriptContext) {
  // Check that SOME output pays current_beneficiary
  expect Some(out) = list.find(ctx.transaction.outputs, fn(o) {
    o.address == datum.current_beneficiary && o.value >= datum.min_payment
  })
  ...
}
@

If @datum.current_beneficiary@ can be set to the script's own address (either
through initialization or a state transition), the continuation output satisfies
this check automatically.

== Attack Flow ==

1. Set @current_beneficiary@ (or similar address field) to the script address
2. Create a transaction that spends the script UTxO
3. The continuation output satisfies the "payment" check
4. No real payment to the intended recipient occurs

== Mitigation ==

1. __Validate address is different__: Check that the beneficiary address is NOT
   the script's own address.

2. __Use credential type restrictions__: Only allow PubKeyCredential for
   beneficiary fields (not ScriptCredential).

3. __External payment validation__: Check that a SEPARATE output (not the
   continuation) pays the beneficiary.

This threat model:

1. Picks a spent script input to identify the script and its payment credential
2. Picks a script output carrying an inline datum (a candidate continuation output)
3. Replaces the payment credential of every address-like structure in that datum
   with the script's own credential
4. Reports a vulnerability if the modified transaction still validates
-}
module Convex.ThreatModel.SelfReferenceInjection (
  -- * Threat models
  selfReferenceInjection,
  selfReferenceInjectionWith,

  -- * Datum transformation
  injectScriptCredential,
  isAddressLikeStructure,
) where

import Control.Monad (when)
import Convex.ThreatModel
import Data.ByteString qualified as BS

{- | Check for Self-Reference Injection vulnerabilities (standard, non-verbose mode).

For a transaction with script inputs and outputs containing inline datums:

* Identifies the script address from the spent script input
* Finds continuation outputs (script outputs with inline datums)
* Replaces address credentials in the datum with the script's own credential
* If the transaction still validates, the script is vulnerable to self-reference

This catches the "king of the hill" vulnerability where setting @current_king@
to the script address allows bypassing the payment requirement.

The attack works because:

1. Script checks "some output pays current_king at least X"
2. When current_king = script_address, the continuation output satisfies this
3. The attacker gets the "king" status without actually paying anyone

Equivalent to @'selfReferenceInjectionWith' False@.
-}
selfReferenceInjection :: ThreatModel ()
selfReferenceInjection = selfReferenceInjectionWith quiet
 where
  -- Non-verbose: only the standard counterexample explanation is emitted.
  quiet = False

{- | Check for Self-Reference Injection with configurable verbosity.

@
selfReferenceInjectionWith True  -- Verbose mode with detailed counterexamples
selfReferenceInjectionWith False -- Standard mode
@

Steps:

1. Pick a spent input whose address is not a key address, and take the raw
   script hash from its payment credential.
2. Pick an output that is at a non-key address and carries an inline datum.
3. Rewrite that datum with 'injectScriptCredential', using the encoded
   credential @Constr 1 [Bytes script_hash]@.
4. Require that the rewritten transaction does NOT validate.

Any step that cannot apply (no script input, no script payment credential,
no candidate output, or a datum that the rewrite leaves unchanged) is a
failed precondition, not a finding.

When @verbose@ is 'True', two extra paragraphs are added to the counterexample
before the standard explanation. They name the script address and the index of
the targeted output.
-}
selfReferenceInjectionWith :: Bool -> ThreatModel ()
selfReferenceInjectionWith verbose = Named "Self-Reference Injection" $ do
  -- Get all inputs and outputs
  inputs <- getTxInputs
  outputs <- getTxOutputs

  -- Script inputs are the inputs whose address is not a key address.
  let scriptInputs = filter (not . isKeyAddressAny . addressOf) inputs

  -- Precondition: must have at least one script input
  threatPrecondition $ ensure (not $ null scriptInputs)

  -- Pick a script input; its address tells us which credential to inject.
  scriptInput <- pickAny scriptInputs
  let scriptAddr = addressOf scriptInput

  -- 'extractScriptCredential' yields Nothing for Byron and key-payment
  -- addresses. Treat that as an unmet precondition rather than a finding.
  credBytes <- case extractScriptCredential scriptAddr of
    Nothing -> failPrecondition "Script input has no script payment credential"
    Just hashBytes -> pure hashBytes

  -- Candidate continuation outputs: script outputs carrying an inline datum.
  let continuationOutputs = filter isScriptOutputWithInlineDatum outputs

  -- Precondition: there must be at least one continuation output
  threatPrecondition $ ensure (not $ null continuationOutputs)

  -- Pick a target output
  target <- pickAny continuationOutputs

  -- Extract the inline datum
  originalDatum <- case getInlineDatum target of
    Nothing -> failPrecondition "Script output missing inline datum"
    Just sd -> pure sd

  -- Build the script's credential as ScriptData
  -- ScriptCredential = Constr 1 [Bytes script_hash]
  let scriptCredData = ScriptDataConstructor 1 [ScriptDataBytes credBytes]
      -- Try to inject the script credential into address fields
      modifiedDatum = injectScriptCredential scriptCredData originalDatum

  -- Only proceed if at least one address field was actually rewritten;
  -- otherwise the "attack" transaction would be identical to the original.
  threatPrecondition $ ensure (modifiedDatum /= originalDatum)

  when verbose $ do
    counterexampleTM $
      paragraph
        [ "The transaction contains a script input at address"
        , show (prettyAddress scriptAddr)
        , "and a continuation output at index"
        , show (outputIx target)
        , "with an inline datum."
        ]

    counterexampleTM $
      paragraph
        [ "Testing if address fields in the datum can be replaced with"
        , "the script's own address while still passing validation."
        ]

  counterexampleTM $
    paragraph
      [ "Self-reference injection: replaced address credentials in datum"
      , "with the script's own credential. If this validates, the script"
      , "doesn't prevent setting address fields to its own address,"
      , "allowing bypass of payment requirements."
      ]

  -- Try to validate with the modified datum: the check fails (a vulnerability
  -- is reported) if the rewritten transaction is accepted.
  shouldNotValidate $ changeDatumOf target (toInlineDatum modifiedDatum)

{- | Raw bytes of the script hash (28 bytes) in a script address's payment credential.

Returns 'Nothing' for Shelley addresses whose payment credential is a key
hash, and for every Byron address. The staking part of a Shelley address is
ignored.
-}
extractScriptCredential :: AddressAny -> Maybe BS.ByteString
extractScriptCredential (AddressByron _) = Nothing
extractScriptCredential (AddressShelley (ShelleyAddress _ payment _)) =
  scriptHashBytes (fromShelleyPaymentCredential payment)
 where
  -- Only script-hash payment credentials yield bytes.
  scriptHashBytes (PaymentCredentialByScript h) = Just (serialiseToRawBytes h)
  scriptHashBytes (PaymentCredentialByKey _) = Nothing

{- | Inject a script credential into all address-like structures in a datum.

An address in Plutus is encoded as:
@Constr 0 [credential, staking_credential]@

Where credential is either:

- @Constr 0 [Bytes pubkey_hash]@ (PubKeyCredential)
- @Constr 1 [Bytes script_hash]@ (ScriptCredential)

The datum is walked recursively through constructor fields, list items, and
both the keys and values of maps. Every structure recognised by
'isAddressLikeStructure' has its payment credential replaced by @newCred@,
whether that credential was a PubKeyCredential or a ScriptCredential. The
staking part of a matched address is kept as-is and is not descended into.

This simulates an attacker who sets a beneficiary address to the script's own
address. Recognition is purely structural, so any value shaped like
@Constr 0 [Constr 0|1 [Bytes (28 bytes)], _]@ is rewritten, even if it does
not actually represent an address.
-}
injectScriptCredential :: ScriptData -> ScriptData -> ScriptData
injectScriptCredential newCred = go
 where
  go :: ScriptData -> ScriptData
  go sd@(ScriptDataConstructor 0 [_, stakingCred])
    | isAddressLikeStructure sd =
        -- Looks like a Plutus Address: swap only the payment credential.
        ScriptDataConstructor 0 [newCred, stakingCred]
  go (ScriptDataConstructor idx fields) =
    ScriptDataConstructor idx (map go fields)
  go (ScriptDataList items) =
    ScriptDataList (map go items)
  go (ScriptDataMap entries) =
    ScriptDataMap [(go k, go v) | (k, v) <- entries]
  -- Leaves (integers, bytes) contain no addresses.
  go other = other

{- | Check if a ScriptData structure looks like a Plutus Address.

A Plutus Address is: @Constr 0 [credential, staking_credential]@

Where:

- credential is @Constr 0 [Bytes pubkey_hash]@ (PubKeyCredential) or
  @Constr 1 [Bytes script_hash]@ (ScriptCredential), with a 28-byte hash
- staking_credential is optional: @Constr 0 [stake_cred]@ when present
  (Just/Some) and @Constr 1 []@ when absent (Nothing/None)

Only the outer @Constr 0@ with exactly two fields and the shape of the
credential are checked. The staking field is not inspected at all.
-}
isAddressLikeStructure :: ScriptData -> Bool
isAddressLikeStructure (ScriptDataConstructor 0 [cred, _stakingCred]) =
  isCredentialLike cred
 where
  -- PubKeyCredential = Constr 0 [Bytes (28 bytes)]
  -- ScriptCredential = Constr 1 [Bytes (28 bytes)]
  isCredentialLike :: ScriptData -> Bool
  isCredentialLike (ScriptDataConstructor n [ScriptDataBytes bs]) =
    n `elem` [0, 1] && BS.length bs == 28
  isCredentialLike _ = False
isAddressLikeStructure _ = False

-- | Does the output sit at a non-key (script) address and carry an inline datum?
--
-- The address is checked first. 'hasInlineDatum' is only consulted for outputs
-- whose address is not a key address.
isScriptOutputWithInlineDatum :: Output -> Bool
isScriptOutputWithInlineDatum output
  | isKeyAddressAny (addressOf output) = False
  | otherwise = hasInlineDatum output

-- | 'True' exactly when the output's datum is an inline datum, i.e. when
-- 'getInlineDatum' finds one.
hasInlineDatum :: Output -> Bool
hasInlineDatum = maybe False (const True) . getInlineDatum

-- | The inline datum of an output as plain 'ScriptData'.
--
-- Returns 'Nothing' for outputs with no datum or with any non-inline datum.
getInlineDatum :: Output -> Maybe ScriptData
getInlineDatum output =
  let txOutDatum = datumOfTxOut (outputTxOut output)
   in case txOutDatum of
        TxOutDatumInline _ hashable -> Just (getScriptData hashable)
        _ -> Nothing

-- | Wrap 'ScriptData' as an inline @Datum@ (TxOutDatum CtxTx Era), tagged
-- with the Conway era witness.
--
-- NOTE(review): goes through 'unsafeHashableScriptData', which presumably
-- derives the hashable form by re-serialising the value rather than keeping
-- any original bytes. Confirm this is acceptable for callers that compare
-- datum hashes.
toInlineDatum :: ScriptData -> Datum
toInlineDatum = TxOutDatumInline BabbageEraOnwardsConway . unsafeHashableScriptData