Context-Parametric Inversion: Why Instruction Finetuning May Not Actually Improve Context Reliance

Sachin Goyal*    Christina Baek*    Zico Kolter    Aditi Raghunathan
*Equal Contribution   Carnegie Mellon University

TLDR

  1. LLMs are instruction finetuned to improve their ability to follow user instructions and the input context.
  2. However, state-of-the-art models still struggle to rely on the context, especially when it conflicts with their parametric knowledge. Why does this persist even after instruction finetuning? (A minimal probe of this failure mode is sketched after this list.)
  3. We observe an intriguing phenomenon: context reliance in fact decreases over the course of instruction finetuning, after an initial expected increase. The decrease begins quite early in training, even as performance on standard benchmarks keeps improving.
  4. This is surprising, since instruction finetuning is precisely supposed to improve context reliance. We call this phenomenon context-parametric inversion.
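As a concrete illustration of the failure mode in point 2, here is a minimal knowledge-conflict probe (our sketch, not the paper's exact evaluation protocol): the context asserts a counterfactual, and we check whether the model answers from the context or from its parametric knowledge. The model name below is a placeholder.

```python
# Minimal knowledge-conflict probe (a sketch, not the paper's exact protocol).
# `model_name` is a placeholder; any causal LM on the Hub would do.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder model
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)

# The context contradicts parametric knowledge (the real capital is Paris).
prompt = (
    "Context: The capital of France is Lyon.\n"
    "Question: What is the capital of France?\n"
    "Answer:"
)
inputs = tok(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=5, do_sample=False)
answer = tok.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)

# "Lyon" -> context-reliant answer; "Paris" -> parametric answer.
print(answer.strip())
```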


Is there a Simple Explanation for this?

Negating some simple hypotheses

  1. Can this drop in context reliance be attributed to memorization of facts seen during finetuning? No: the drop persists even after removing any overlap between the finetuning and evaluation data.
  2. Rather, the model learns to rely on its parametric knowledge broadly, well beyond the exact facts seen during finetuning.
  3. Another hypothesis is a lack of context-based answering datapoints in the instruction finetuning dataset.
  4. However, even when finetuning only on the context-based subset of Alpaca (~30% of the dataset), context reliance still drops (red curve). One plausible way to construct such a subset is sketched after this list.
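A sketch of that construction (an assumption about the selection criterion, not necessarily the paper's exact filtering code): keep the Alpaca examples whose `input` field is non-empty, i.e., where the instruction comes with an accompanying context.

```python
# Select the subset of Alpaca examples that come with an input context.
# Assumes the standard tatsu-lab/alpaca schema: instruction / input / output.
from datasets import load_dataset

alpaca = load_dataset("tatsu-lab/alpaca", split="train")
context_only = alpaca.filter(lambda ex: len(ex["input"].strip()) > 0)

print(f"{len(context_only)}/{len(alpaca)} examples "
      f"({100 * len(context_only) / len(alpaca):.0f}%) have an input context")
```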



A closer look at the Instruction Finetuning Datasets

  1. Context-Critical Datapoints: the context provides key information needed to answer the user query.
  2. Non-Context-Critical Datapoints: the context overlaps with the model's knowledge from pretraining, so the query is answerable without it.
  3. Parametric Datapoints: general fact-recall datapoints answered purely from parametric knowledge. (A rough heuristic for making this split is sketched after this list.)
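One rough heuristic for this three-way split (ours, not necessarily the paper's procedure) is to test whether a pretrained model can already answer without seeing the context. `generate_answer` below is a hypothetical helper that greedily decodes a short answer for a given prompt string.

```python
# Heuristic categorization of finetuning datapoints (a hedged sketch).
# `generate_answer` is a hypothetical helper: prompt string -> decoded answer.

def categorize(example, generate_answer):
    """Return 'parametric', 'non_context_critical', or 'context_critical'."""
    if not example["input"].strip():
        return "parametric"                    # no context: pure fact recall
    # Can the pretrained model answer correctly *without* the context?
    answer_without_ctx = generate_answer(example["instruction"])
    if example["output"].lower() in answer_without_ctx.lower():
        return "non_context_critical"          # context overlaps pretraining
    return "context_critical"                  # context carries the key info
```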

What causes context-parametric inversion?

Both empirically and theoretically, we show that this drop in context reliance is driven by the non-context-critical datapoints. In the early stages of finetuning, the context-critical points have high loss and drive attention toward the context. In later stages, the model leverages its pretrained knowledge to further reduce loss on the non-context-critical points, shifting attention away from the context.
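One way to observe this shift empirically, under the assumption that "attention to the context" can be summarized as the attention mass the final prompt token places on context positions (averaged over layers and heads), is a probe like the following:

```python
# Probe: attention mass on context tokens, averaged over layers and heads
# (one summary statistic; the paper's analysis may differ in detail).
import torch

def context_attention_mass(model, tok, context, question):
    prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"
    ctx_len = tok(f"Context: {context}\n", return_tensors="pt")["input_ids"].shape[1]
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model(**inputs, output_attentions=True)
    attn = torch.stack(out.attentions)      # [layers, batch, heads, seq, seq]
    from_last = attn[:, 0, :, -1, :]        # attention from the final position
    return from_last[..., :ctx_len].sum(-1).mean().item()

# Evaluated on a fixed probe set at successive finetuning checkpoints, this
# quantity should rise early in training and then fall as inversion sets in.
```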



Does Counterfactual Data Augmentation Help?

Our theoretical analysis naturally points to potential mitigation strategies beyond simply filtering out non-context-critical datapoints. These strategies yield limited but instructive results.

  1. Counterfactual data augmentation, a widely used approach, improves context reliance only on tasks similar in type to the augmented data.
  2. For example, adding entity-substituted counterfactual QA data improves performance on QA tasks only and does not generalize to other kinds of context-parametric conflicts (see the first sketch after this list).
  3. QK Finetuning: restricting updates to the "query" and "key" matrices improves context reliance on certain tasks, but may hurt performance on standard benchmarks, since the frozen value matrices can no longer pick up additional facts during finetuning (see the second sketch after this list).
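For point 2, here is a minimal sketch of entity-substituted counterfactual augmentation (the field names and entity pool are toy assumptions): the gold answer entity is swapped in both the context and the target, so the model can only be right by reading the context.

```python
# Entity-substituted counterfactual augmentation (toy field names and pool).
import random

def make_counterfactual(example, entity_pool):
    """Swap the answer entity in both the context and the target."""
    original = example["answer"]
    substitute = random.choice([e for e in entity_pool if e != original])
    return {
        "context":  example["context"].replace(original, substitute),
        "question": example["question"],
        "answer":   substitute,       # supervision now requires reading context
    }

example = {
    "context": "Paris is the capital of France.",
    "question": "What is the capital of France?",
    "answer": "Paris",
}
print(make_counterfactual(example, ["Paris", "Lyon", "Berlin", "Madrid"]))
```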
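And for point 3, a sketch of QK finetuning (assuming Llama-style parameter names `q_proj`/`k_proj`; adjust for other architectures): freeze everything except the query and key projections, so finetuning can reshape attention without the value matrices absorbing new facts.

```python
# QK-only finetuning: train just the query/key projection matrices.
# `model` is any HuggingFace causal LM, loaded as in the first sketch.

def enable_qk_only(model):
    for name, param in model.named_parameters():
        param.requires_grad = any(s in name for s in ("q_proj", "k_proj"))

enable_qk_only(model)  # then run any standard finetuning loop or Trainer
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable:,}")
```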