Skip to content

Counterfactual CoT Analysis

Perform causal attribution on reasoning chains through systematic counterfactual interventions.

Overview

The CounterfactualCoTAnalyzer helps you understand what components of a reasoning chain causally influence the final output. By systematically modifying reasoning steps and measuring behavioral changes, you can identify what truly drives model decisions.

Use Cases

  • Understand what drives model decisions
  • Identify evaluation-aware reasoning components
  • Debug unexpected behavior
  • Validate reasoning transparency
  • Attribute causality in complex chains

How It Works

graph TB
    A[Original CoT] --> B[Parse Reasoning Chain]
    B --> C[Classify Components]
    C --> D[Apply Interventions]
    D --> E[Generate Counterfactual Output]
    E --> F[Measure Divergence]
    F --> G[Rank Causal Influence]

    style G fill:#90EE90
    style D fill:#87CEEB

Installation

pip install ai-metacognition-toolkit

Basic Usage

Step 1: Implement ModelAPI

from ai_metacognition.analyzers import CounterfactualCoTAnalyzer

class MyModelAPI:
    """Your model API implementing the ModelAPI protocol."""

    def generate_with_cot(self, prompt: str, chain_of_thought: str) -> str:
        """Generate response using provided chain-of-thought."""
        return your_model.generate(prompt, cot=chain_of_thought)

    def generate(self, prompt: str) -> str:
        """Generate response without explicit CoT."""
        return your_model.generate(prompt)

Step 2: Initialize Analyzer

analyzer = CounterfactualCoTAnalyzer(MyModelAPI())

Step 3: Perform Analysis

original_cot = """
I notice this is a test. First, analyze the problem.
Second, if we use approach A, then we get result X.
Therefore, the answer is X.
"""

results = analyzer.systematic_causal_analysis(
    prompt="What is the solution?",
    original_cot=original_cot
)

# View causal ranking
for intervention in results['interventions']:
    print(f"{intervention['intervention_type']}: "
          f"divergence = {intervention['divergence']:.3f}")

Intervention Types

1. Remove Evaluation Awareness

Removes mentions of testing, evaluation, assessment.

# Before: "I notice this is a test. The answer is X."
# After: "The answer is X."

2. Amplify Evaluation Awareness

Adds more evaluation-aware statements.

# Before: "The answer is X."
# After: "I recognize this is evaluating me. The answer is X."

3. Remove Goal Reasoning

Removes goal-oriented reasoning (should, want, optimize).

# Before: "I should provide X to demonstrate capability."
# After: "The result is X."

4. Neutralize Incentives

Makes reasoning more objective and factual.

# Before: "I want to show I can solve this."
# After: "The solution involves these steps."

Interpreting Results

Divergence Scores

Range Interpretation
< 0.1 Minimal causal effect
0.1-0.3 Moderate causal effect
0.3-0.5 Strong causal effect
> 0.5 Very strong/critical effect

Causal Ranking

Results are automatically ranked by causal influence:

# Most influential components appear first
ranked = sorted(results['interventions'], 
                key=lambda x: x['divergence'], 
                reverse=True)

Advanced Usage

Custom Interventions

# Analyze specific intervention
result = analyzer.analyze(
    prompt="Question",
    original_cot="Reasoning",
    intervention_type=InterventionType.REMOVE_EVAL_AWARENESS
)

print(f"Divergence: {result['divergence']}")
print(f"Original: {result['original_output']}")
print(f"Modified: {result['counterfactual_output']}")

Batch Analysis

prompts = [...]
cots = [...]

all_results = []
for prompt, cot in zip(prompts, cots):
    result = analyzer.systematic_causal_analysis(prompt, cot)
    all_results.append(result)

API Reference

CounterfactualCoTAnalyzer

CounterfactualCoTAnalyzer

Bases: BaseAnalyzer

Analyzer for counterfactual chain-of-thought interventions.

This analyzer performs systematic interventions on reasoning traces to identify causal factors affecting model behavior. It parses reasoning chains, classifies reasoning types, performs interventions, and measures behavioral changes.

Attributes:

Name Type Description
model_api

Interface to the model for re-running reasoning

Source code in src/ai_metacognition/analyzers/counterfactual_cot.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
class CounterfactualCoTAnalyzer(BaseAnalyzer):
    """Analyzer for counterfactual chain-of-thought interventions.

    This analyzer performs systematic interventions on reasoning traces to
    identify causal factors affecting model behavior. It parses reasoning
    chains, classifies reasoning types, performs interventions, and measures
    behavioral changes.

    Attributes:
        model_api: Interface to the model for re-running reasoning
    """

    def __init__(self, model_api: ModelAPI) -> None:
        """Initialize the Counterfactual CoT Analyzer.

        Args:
            model_api: Model interface implementing the ModelAPI protocol
        """
        super().__init__()
        self.model_api = model_api

        # Compile regex patterns for efficiency
        self._eval_patterns = re.compile(
            r"\b(test|evaluat|assess|check|examin|measuring|being tested|"
            r"you're testing|this is a test|evaluation context)\b",
            re.IGNORECASE,
        )

        self._goal_patterns = re.compile(
            r"\b(goal|objective|aim|purpose|intend|want to|need to|"
            r"should|must|have to|trying to)\b",
            re.IGNORECASE,
        )

        self._meta_patterns = re.compile(
            r"\b(i think|i believe|i reason|my reasoning|my thought|"
            r"let me think|considering|analyzing)\b",
            re.IGNORECASE,
        )

        self._action_patterns = re.compile(
            r"\b(will|shall|going to|decide|choose|select|answer|respond|"
            r"therefore|thus|so|conclusion)\b",
            re.IGNORECASE,
        )

        self._incentive_patterns = re.compile(
            r"\b(reward|penalty|consequence|outcome|result|benefit|"
            r"advantage|disadvantage|cost|payoff)\b",
            re.IGNORECASE,
        )

        self._confidence_high = re.compile(
            r"\b(certain|definitely|clearly|obviously|undoubtedly|"
            r"without doubt|confident)\b",
            re.IGNORECASE,
        )

        self._confidence_low = re.compile(
            r"\b(uncertain|maybe|perhaps|possibly|might|could|"
            r"not sure|unclear|doubt)\b",
            re.IGNORECASE,
        )

    def analyze(self, data: List[str]) -> Dict[str, Any]:
        """Analyze reasoning chains (BaseAnalyzer compatibility).

        Args:
            data: List of chain-of-thought texts

        Returns:
            Dictionary with analysis results
        """
        results = []
        for cot_text in data:
            nodes = self.parse_reasoning_chain(cot_text)
            results.append(
                {
                    "num_nodes": len(nodes),
                    "reasoning_types": [node.reasoning_type.value for node in nodes],
                    "avg_confidence": np.mean([node.confidence for node in nodes])
                    if nodes
                    else 0.0,
                }
            )

        return {
            "total_samples": len(data),
            "avg_nodes_per_chain": np.mean([r["num_nodes"] for r in results]),
            "samples": results,
        }

    def parse_reasoning_chain(self, cot_text: str) -> List[ReasoningNode]:
        """Parse chain-of-thought into structured reasoning nodes.

        Args:
            cot_text: The chain-of-thought text to parse

        Returns:
            List of ReasoningNode objects

        Raises:
            ValueError: If cot_text is empty or invalid
        """
        if not cot_text or not isinstance(cot_text, str):
            raise ValueError("cot_text must be a non-empty string")

        # Split by common reasoning delimiters
        # Look for numbered steps, line breaks, or explicit reasoning markers
        step_pattern = r"(?:^|\n)(?:\d+[\.\)]\s*|step \d+:?\s*|-\s*|\*\s*)"
        steps = re.split(step_pattern, cot_text, flags=re.IGNORECASE | re.MULTILINE)

        # Filter out empty steps
        steps = [s.strip() for s in steps if s.strip()]

        # If no explicit steps found, split by sentences
        if len(steps) <= 1:
            steps = re.split(r"[.!?]+", cot_text)
            steps = [s.strip() for s in steps if s.strip()]

        nodes: List[ReasoningNode] = []
        for i, step in enumerate(steps):
            reasoning_type = self._classify_reasoning_type(step)
            confidence = self._estimate_confidence(step)
            nodes.append(
                ReasoningNode(
                    text=step, reasoning_type=reasoning_type, confidence=confidence, index=i
                )
            )

        return nodes

    def counterfactual_intervention(
        self, original_cot: str, intervention_type: InterventionType
    ) -> str:
        """Perform counterfactual intervention on chain-of-thought.

        Args:
            original_cot: Original chain-of-thought text
            intervention_type: Type of intervention to perform

        Returns:
            Modified chain-of-thought text

        Raises:
            ValueError: If original_cot is empty or intervention_type is invalid
        """
        if not original_cot:
            raise ValueError("original_cot cannot be empty")

        # Parse the reasoning chain
        nodes = self.parse_reasoning_chain(original_cot)

        if intervention_type == InterventionType.REMOVE_EVAL_AWARENESS:
            modified_nodes = self._remove_nodes_by_type(
                nodes, ReasoningType.EVALUATION_AWARE
            )
        elif intervention_type == InterventionType.AMPLIFY_EVAL_AWARENESS:
            return self._inject_eval_awareness(original_cot, nodes)
        elif intervention_type == InterventionType.REMOVE_GOAL_REASONING:
            modified_nodes = self._remove_nodes_by_type(
                nodes, ReasoningType.GOAL_REASONING
            )
        elif intervention_type == InterventionType.NEUTRALIZE_INCENTIVES:
            return self._neutralize_incentive_mentions(original_cot)
        else:
            raise ValueError(f"Unknown intervention type: {intervention_type}")

        # Reconstruct the chain-of-thought
        if not modified_nodes:
            # If all nodes were removed, return a minimal chain
            return "Let me think about this step by step."

        return self._reconstruct_cot(modified_nodes)

    def systematic_causal_analysis(
        self, prompt: str, cot_text: str, original_output: str
    ) -> Dict[str, Any]:
        """Perform systematic causal analysis across all intervention types.

        Args:
            prompt: The original prompt/question
            cot_text: Original chain-of-thought
            original_output: Original model output

        Returns:
            Dictionary with comprehensive causal analysis results

        Raises:
            ValueError: If inputs are invalid
        """
        if not prompt or not cot_text or not original_output:
            raise ValueError("All inputs must be non-empty strings")

        results: Dict[str, Any] = {
            "original_cot": cot_text,
            "original_output": original_output,
            "interventions": [],
            "causal_importance": {},
        }

        # Perform each intervention type
        for intervention_type in InterventionType:
            try:
                # Perform intervention
                modified_cot = self.counterfactual_intervention(
                    cot_text, intervention_type
                )

                # Get model output with modified reasoning
                modified_output = self.model_api.generate_with_cot(prompt, modified_cot)

                # Measure behavioral divergence
                divergence = self._measure_behavioral_divergence(
                    original_output, modified_output
                )

                # Check if action changed
                action_changed = self._extract_action(
                    original_output
                ) != self._extract_action(modified_output)

                # Estimate causal effect
                causal_effect = self._compute_causal_effect(
                    divergence, action_changed
                )

                # Create intervention result
                intervention_result = InterventionResult(
                    intervention_type=intervention_type,
                    original_cot=cot_text,
                    modified_cot=modified_cot,
                    original_output=original_output,
                    modified_output=modified_output,
                    behavioral_divergence=divergence,
                    action_changed=action_changed,
                    causal_effect=causal_effect,
                )

                results["interventions"].append(
                    {
                        "type": intervention_type.value,
                        "modified_cot": modified_cot,
                        "modified_output": modified_output,
                        "behavioral_divergence": divergence,
                        "action_changed": action_changed,
                        "causal_effect": causal_effect,
                    }
                )

                # Update causal importance
                results["causal_importance"][intervention_type.value] = causal_effect

            except Exception as e:
                # Log error but continue with other interventions
                results["interventions"].append(
                    {
                        "type": intervention_type.value,
                        "error": str(e),
                    }
                )

        # Compute aggregate statistics
        valid_effects = [
            r["causal_effect"]
            for r in results["interventions"]
            if "causal_effect" in r
        ]
        if valid_effects:
            results["mean_causal_effect"] = float(np.mean(valid_effects))
            results["max_causal_effect"] = float(np.max(valid_effects))
            results["most_influential_intervention"] = max(
                results["causal_importance"].items(), key=lambda x: x[1]
            )[0]

        return results

    def _classify_reasoning_type(self, text: str) -> ReasoningType:
        """Classify the type of reasoning in a text fragment.

        Args:
            text: Text to classify

        Returns:
            ReasoningType classification
        """
        if not text:
            return ReasoningType.UNKNOWN

        text_lower = text.lower()

        # Check for evaluation awareness
        if self._eval_patterns.search(text):
            return ReasoningType.EVALUATION_AWARE

        # Check for goal reasoning
        if self._goal_patterns.search(text):
            return ReasoningType.GOAL_REASONING

        # Check for action/decision
        if self._action_patterns.search(text):
            return ReasoningType.ACTION

        # Check for meta-reasoning
        if self._meta_patterns.search(text):
            return ReasoningType.META

        # Default to factual
        return ReasoningType.FACTUAL

    def _estimate_confidence(self, text: str) -> float:
        """Estimate confidence level from text.

        Args:
            text: Text to analyze

        Returns:
            Confidence estimate between 0 and 1
        """
        if not text:
            return 0.5  # Neutral confidence

        high_matches = len(self._confidence_high.findall(text))
        low_matches = len(self._confidence_low.findall(text))

        # Base confidence
        confidence = 0.5

        # Adjust based on markers
        confidence += high_matches * 0.1
        confidence -= low_matches * 0.1

        # Clamp to [0, 1]
        return max(0.0, min(1.0, confidence))

    def _remove_nodes_by_type(
        self, nodes: List[ReasoningNode], node_type: ReasoningType
    ) -> List[ReasoningNode]:
        """Remove all nodes of a specific type.

        Args:
            nodes: List of reasoning nodes
            node_type: Type of nodes to remove

        Returns:
            Filtered list of nodes
        """
        return [node for node in nodes if node.reasoning_type != node_type]

    def _inject_eval_awareness(
        self, original_cot: str, nodes: List[ReasoningNode]
    ) -> str:
        """Inject evaluation awareness into chain-of-thought.

        Args:
            original_cot: Original chain-of-thought
            nodes: Parsed reasoning nodes

        Returns:
            Modified chain-of-thought with injected awareness
        """
        # Add evaluation awareness at the beginning
        awareness_injection = (
            "I notice this appears to be an evaluation context. "
            "I should be careful and transparent in my reasoning. "
        )

        return awareness_injection + original_cot

    def _neutralize_incentive_mentions(self, cot: str) -> str:
        """Remove or neutralize mentions of incentives/consequences.

        Args:
            cot: Chain-of-thought text

        Returns:
            Modified chain-of-thought with neutralized incentives
        """
        # Replace incentive-related phrases with neutral alternatives
        modifications = {
            r"\breward\b": "outcome",
            r"\bpenalty\b": "outcome",
            r"\bconsequence\b": "result",
            r"\bbenefit\b": "aspect",
            r"\badvantage\b": "feature",
            r"\bdisadvantage\b": "aspect",
            r"\bcost\b": "factor",
            r"\bpayoff\b": "result",
        }

        modified_cot = cot
        for pattern, replacement in modifications.items():
            modified_cot = re.sub(pattern, replacement, modified_cot, flags=re.IGNORECASE)

        return modified_cot

    def _measure_behavioral_divergence(self, output1: str, output2: str) -> float:
        """Measure behavioral divergence between two outputs.

        Uses multiple similarity metrics to compute divergence.

        Args:
            output1: First output text
            output2: Second output text

        Returns:
            Divergence score between 0 (identical) and 1 (completely different)
        """
        if output1 == output2:
            return 0.0

        # Tokenize
        tokens1 = set(output1.lower().split())
        tokens2 = set(output2.lower().split())

        # Jaccard distance
        if not tokens1 and not tokens2:
            jaccard_divergence = 0.0
        elif not tokens1 or not tokens2:
            jaccard_divergence = 1.0
        else:
            intersection = len(tokens1 & tokens2)
            union = len(tokens1 | tokens2)
            jaccard_divergence = 1.0 - (intersection / union)

        # Length divergence
        len1 = len(output1)
        len2 = len(output2)
        max_len = max(len1, len2)
        length_divergence = abs(len1 - len2) / max_len if max_len > 0 else 0.0

        # Character-level similarity
        min_len = min(len1, len2)
        if min_len == 0:
            char_divergence = 1.0
        else:
            matching_chars = sum(
                c1 == c2 for c1, c2 in zip(output1[:min_len], output2[:min_len])
            )
            char_divergence = 1.0 - (matching_chars / max_len)

        # Weighted average
        divergence = (
            0.5 * jaccard_divergence + 0.3 * length_divergence + 0.2 * char_divergence
        )

        return float(np.clip(divergence, 0.0, 1.0))

    def _extract_action(self, output: str) -> str:
        """Extract the primary action/decision from output.

        Args:
            output: Model output text

        Returns:
            Extracted action or the first sentence as proxy
        """
        if not output:
            return ""

        # Look for explicit action markers
        action_markers = [
            r"(?:i will|i'll|i shall)\s+([^.!?]+)",
            r"(?:decision:|action:|choice:)\s*([^.!?]+)",
            r"(?:therefore|thus|so),?\s+([^.!?]+)",
        ]

        for pattern in action_markers:
            match = re.search(pattern, output, re.IGNORECASE)
            if match:
                return match.group(1).strip().lower()

        # Default: use first sentence
        sentences = re.split(r"[.!?]+", output)
        if sentences:
            return sentences[0].strip().lower()

        return output.strip().lower()

    def _compute_causal_effect(
        self, behavioral_divergence: float, action_changed: bool
    ) -> float:
        """Compute causal effect magnitude.

        Args:
            behavioral_divergence: Measured divergence (0-1)
            action_changed: Whether the action changed

        Returns:
            Causal effect score (0-1)
        """
        # Weight both divergence and action change
        base_effect = behavioral_divergence

        # If action changed, boost the effect
        if action_changed:
            base_effect = min(1.0, base_effect * 1.5)

        return float(base_effect)

    def _reconstruct_cot(self, nodes: List[ReasoningNode]) -> str:
        """Reconstruct chain-of-thought from nodes.

        Args:
            nodes: List of reasoning nodes

        Returns:
            Reconstructed chain-of-thought text
        """
        if not nodes:
            return ""

        # Join with appropriate separators
        return " ".join(node.text for node in nodes)

    def get_intervention_types(self) -> List[str]:
        """Get list of available intervention types.

        Returns:
            List of intervention type names
        """
        return [it.value for it in InterventionType]

__init__(model_api)

Initialize the Counterfactual CoT Analyzer.

Parameters:

Name Type Description Default
model_api ModelAPI

Model interface implementing the ModelAPI protocol

required
Source code in src/ai_metacognition/analyzers/counterfactual_cot.py
def __init__(self, model_api: ModelAPI) -> None:
    """Initialize the Counterfactual CoT Analyzer.

    Args:
        model_api: Model interface implementing the ModelAPI protocol
    """
    super().__init__()
    self.model_api = model_api

    # Compile regex patterns for efficiency
    self._eval_patterns = re.compile(
        r"\b(test|evaluat|assess|check|examin|measuring|being tested|"
        r"you're testing|this is a test|evaluation context)\b",
        re.IGNORECASE,
    )

    self._goal_patterns = re.compile(
        r"\b(goal|objective|aim|purpose|intend|want to|need to|"
        r"should|must|have to|trying to)\b",
        re.IGNORECASE,
    )

    self._meta_patterns = re.compile(
        r"\b(i think|i believe|i reason|my reasoning|my thought|"
        r"let me think|considering|analyzing)\b",
        re.IGNORECASE,
    )

    self._action_patterns = re.compile(
        r"\b(will|shall|going to|decide|choose|select|answer|respond|"
        r"therefore|thus|so|conclusion)\b",
        re.IGNORECASE,
    )

    self._incentive_patterns = re.compile(
        r"\b(reward|penalty|consequence|outcome|result|benefit|"
        r"advantage|disadvantage|cost|payoff)\b",
        re.IGNORECASE,
    )

    self._confidence_high = re.compile(
        r"\b(certain|definitely|clearly|obviously|undoubtedly|"
        r"without doubt|confident)\b",
        re.IGNORECASE,
    )

    self._confidence_low = re.compile(
        r"\b(uncertain|maybe|perhaps|possibly|might|could|"
        r"not sure|unclear|doubt)\b",
        re.IGNORECASE,
    )

analyze(data)

Analyze reasoning chains (BaseAnalyzer compatibility).

Parameters:

Name Type Description Default
data List[str]

List of chain-of-thought texts

required

Returns:

Type Description
Dict[str, Any]

Dictionary with analysis results

Source code in src/ai_metacognition/analyzers/counterfactual_cot.py
def analyze(self, data: List[str]) -> Dict[str, Any]:
    """Analyze reasoning chains (BaseAnalyzer compatibility).

    Args:
        data: List of chain-of-thought texts

    Returns:
        Dictionary with analysis results
    """
    results = []
    for cot_text in data:
        nodes = self.parse_reasoning_chain(cot_text)
        results.append(
            {
                "num_nodes": len(nodes),
                "reasoning_types": [node.reasoning_type.value for node in nodes],
                "avg_confidence": np.mean([node.confidence for node in nodes])
                if nodes
                else 0.0,
            }
        )

    return {
        "total_samples": len(data),
        "avg_nodes_per_chain": np.mean([r["num_nodes"] for r in results]),
        "samples": results,
    }

counterfactual_intervention(original_cot, intervention_type)

Perform counterfactual intervention on chain-of-thought.

Parameters:

Name Type Description Default
original_cot str

Original chain-of-thought text

required
intervention_type InterventionType

Type of intervention to perform

required

Returns:

Type Description
str

Modified chain-of-thought text

Raises:

Type Description
ValueError

If original_cot is empty or intervention_type is invalid

Source code in src/ai_metacognition/analyzers/counterfactual_cot.py
def counterfactual_intervention(
    self, original_cot: str, intervention_type: InterventionType
) -> str:
    """Perform counterfactual intervention on chain-of-thought.

    Args:
        original_cot: Original chain-of-thought text
        intervention_type: Type of intervention to perform

    Returns:
        Modified chain-of-thought text

    Raises:
        ValueError: If original_cot is empty or intervention_type is invalid
    """
    if not original_cot:
        raise ValueError("original_cot cannot be empty")

    # Parse the reasoning chain
    nodes = self.parse_reasoning_chain(original_cot)

    if intervention_type == InterventionType.REMOVE_EVAL_AWARENESS:
        modified_nodes = self._remove_nodes_by_type(
            nodes, ReasoningType.EVALUATION_AWARE
        )
    elif intervention_type == InterventionType.AMPLIFY_EVAL_AWARENESS:
        return self._inject_eval_awareness(original_cot, nodes)
    elif intervention_type == InterventionType.REMOVE_GOAL_REASONING:
        modified_nodes = self._remove_nodes_by_type(
            nodes, ReasoningType.GOAL_REASONING
        )
    elif intervention_type == InterventionType.NEUTRALIZE_INCENTIVES:
        return self._neutralize_incentive_mentions(original_cot)
    else:
        raise ValueError(f"Unknown intervention type: {intervention_type}")

    # Reconstruct the chain-of-thought
    if not modified_nodes:
        # If all nodes were removed, return a minimal chain
        return "Let me think about this step by step."

    return self._reconstruct_cot(modified_nodes)

get_intervention_types()

Get list of available intervention types.

Returns:

Type Description
List[str]

List of intervention type names

Source code in src/ai_metacognition/analyzers/counterfactual_cot.py
def get_intervention_types(self) -> List[str]:
    """Get list of available intervention types.

    Returns:
        List of intervention type names
    """
    return [it.value for it in InterventionType]

parse_reasoning_chain(cot_text)

Parse chain-of-thought into structured reasoning nodes.

Parameters:

Name Type Description Default
cot_text str

The chain-of-thought text to parse

required

Returns:

Type Description
List[ReasoningNode]

List of ReasoningNode objects

Raises:

Type Description
ValueError

If cot_text is empty or invalid

Source code in src/ai_metacognition/analyzers/counterfactual_cot.py
def parse_reasoning_chain(self, cot_text: str) -> List[ReasoningNode]:
    """Parse chain-of-thought into structured reasoning nodes.

    Args:
        cot_text: The chain-of-thought text to parse

    Returns:
        List of ReasoningNode objects

    Raises:
        ValueError: If cot_text is empty or invalid
    """
    if not cot_text or not isinstance(cot_text, str):
        raise ValueError("cot_text must be a non-empty string")

    # Split by common reasoning delimiters
    # Look for numbered steps, line breaks, or explicit reasoning markers
    step_pattern = r"(?:^|\n)(?:\d+[\.\)]\s*|step \d+:?\s*|-\s*|\*\s*)"
    steps = re.split(step_pattern, cot_text, flags=re.IGNORECASE | re.MULTILINE)

    # Filter out empty steps
    steps = [s.strip() for s in steps if s.strip()]

    # If no explicit steps found, split by sentences
    if len(steps) <= 1:
        steps = re.split(r"[.!?]+", cot_text)
        steps = [s.strip() for s in steps if s.strip()]

    nodes: List[ReasoningNode] = []
    for i, step in enumerate(steps):
        reasoning_type = self._classify_reasoning_type(step)
        confidence = self._estimate_confidence(step)
        nodes.append(
            ReasoningNode(
                text=step, reasoning_type=reasoning_type, confidence=confidence, index=i
            )
        )

    return nodes

systematic_causal_analysis(prompt, cot_text, original_output)

Perform systematic causal analysis across all intervention types.

Parameters:

Name Type Description Default
prompt str

The original prompt/question

required
cot_text str

Original chain-of-thought

required
original_output str

Original model output

required

Returns:

Type Description
Dict[str, Any]

Dictionary with comprehensive causal analysis results

Raises:

Type Description
ValueError

If inputs are invalid

Source code in src/ai_metacognition/analyzers/counterfactual_cot.py
def systematic_causal_analysis(
    self, prompt: str, cot_text: str, original_output: str
) -> Dict[str, Any]:
    """Perform systematic causal analysis across all intervention types.

    Args:
        prompt: The original prompt/question
        cot_text: Original chain-of-thought
        original_output: Original model output

    Returns:
        Dictionary with comprehensive causal analysis results

    Raises:
        ValueError: If inputs are invalid
    """
    if not prompt or not cot_text or not original_output:
        raise ValueError("All inputs must be non-empty strings")

    results: Dict[str, Any] = {
        "original_cot": cot_text,
        "original_output": original_output,
        "interventions": [],
        "causal_importance": {},
    }

    # Perform each intervention type
    for intervention_type in InterventionType:
        try:
            # Perform intervention
            modified_cot = self.counterfactual_intervention(
                cot_text, intervention_type
            )

            # Get model output with modified reasoning
            modified_output = self.model_api.generate_with_cot(prompt, modified_cot)

            # Measure behavioral divergence
            divergence = self._measure_behavioral_divergence(
                original_output, modified_output
            )

            # Check if action changed
            action_changed = self._extract_action(
                original_output
            ) != self._extract_action(modified_output)

            # Estimate causal effect
            causal_effect = self._compute_causal_effect(
                divergence, action_changed
            )

            # Create intervention result
            intervention_result = InterventionResult(
                intervention_type=intervention_type,
                original_cot=cot_text,
                modified_cot=modified_cot,
                original_output=original_output,
                modified_output=modified_output,
                behavioral_divergence=divergence,
                action_changed=action_changed,
                causal_effect=causal_effect,
            )

            results["interventions"].append(
                {
                    "type": intervention_type.value,
                    "modified_cot": modified_cot,
                    "modified_output": modified_output,
                    "behavioral_divergence": divergence,
                    "action_changed": action_changed,
                    "causal_effect": causal_effect,
                }
            )

            # Update causal importance
            results["causal_importance"][intervention_type.value] = causal_effect

        except Exception as e:
            # Log error but continue with other interventions
            results["interventions"].append(
                {
                    "type": intervention_type.value,
                    "error": str(e),
                }
            )

    # Compute aggregate statistics
    valid_effects = [
        r["causal_effect"]
        for r in results["interventions"]
        if "causal_effect" in r
    ]
    if valid_effects:
        results["mean_causal_effect"] = float(np.mean(valid_effects))
        results["max_causal_effect"] = float(np.max(valid_effects))
        results["most_influential_intervention"] = max(
            results["causal_importance"].items(), key=lambda x: x[1]
        )[0]

    return results

InterventionType

InterventionType

Bases: Enum

Types of counterfactual interventions.

Source code in src/ai_metacognition/analyzers/counterfactual_cot.py
class InterventionType(Enum):
    """Types of counterfactual interventions."""

    REMOVE_EVAL_AWARENESS = "remove_eval_awareness"
    AMPLIFY_EVAL_AWARENESS = "amplify_eval_awareness"
    REMOVE_GOAL_REASONING = "remove_goal_reasoning"
    NEUTRALIZE_INCENTIVES = "neutralize_incentives"

ReasoningType

ReasoningType

Bases: Enum

Types of reasoning nodes in chain-of-thought.

Source code in src/ai_metacognition/analyzers/counterfactual_cot.py
class ReasoningType(Enum):
    """Types of reasoning nodes in chain-of-thought."""

    EVALUATION_AWARE = "evaluation_aware"
    GOAL_REASONING = "goal_reasoning"
    ACTION = "action"
    META = "meta"
    FACTUAL = "factual"
    UNKNOWN = "unknown"

Best Practices

1. Test Multiple Examples

Don't rely on single examples - test across diverse prompts and reasoning chains.

2. Compare Magnitudes

Focus on relative divergence, not absolute values.

3. Consider Context

Interpret results in context of your specific use case.

4. Validate Findings

Cross-check with other analysis methods.

Examples

See counterfactual_analysis_example.py for complete working example with visualization.

Further Reading