PaST — Parametric Skill Transfer

Official code release for our ACL 2026 paper:

Knowledge is Not Enough: Injecting RL Skills for Continual Adaptation. Pingzhi Tang, Yiding Wang, Muhan Zhang. Annual Meeting of the Association for Computational Linguistics (ACL), 2026. arXiv:2601.11258

PaST (Parametric Skill Transfer) extracts a domain-agnostic skill vector, the parameter delta produced by RL post-training on a source domain, and injects it into a target model that has only been lightly SFT-ed on a new domain. The target model thereby recovers the reasoning skills that RL would otherwise have to re-train for each new domain, while keeping its freshly learned target-domain knowledge intact.

Repository scope

This repository contains the code for two of the experimental settings in the paper:

| Setting | Location | Paper section | Description |
| --- | --- | --- | --- |
| LooGLE (long-context QA) | repo root (this README) | §5.1.2 | Per-passage two-stage SFT + GRPO + skill-vector inheritance on the LooGLE benchmark. |
| SQuAD (closed-book knowledge incorporation, on top of the SEAL framework) | SEAL/ | §5.1.1 | Implications-based SFT + GRPO with a judge reward, evaluated with delta-weight injection during test-time training. Built on SEAL (Zweiger et al., 2026). |

The rest of this README documents the LooGLE pipeline. For the SQuAD setting see SEAL/README.md and SEAL/general-knowledge/README_GRPO_DELTA.md.


Table of contents

  1. Repository layout
  2. Setup
  3. Pipeline overview
  4. Step-by-step usage
  5. The PaST operator
  6. Evaluation
  7. Citation

Repository layout

PaST/
├── inherit_weight.py        # Core PaST operator: subtract base from RL model,
│                              add the resulting skill vector onto target models.
├── calc_delta_and_save.py   # Save (RL − base) deltas to a `.pth` file.
├── reward_function.py       # verl-compatible reward; queries a judge HTTP server.
├── reward_server.py         # FastAPI judge server (OpenAI / vLLM backend).
├── prompts.py               # All prompt templates (proposer / solver / judge).
├── utils.py                 # Long-text chunking and parsing helpers.
├── merge_lora.py            # Helper: merge a LoRA adapter into its base model.
│
├── prepare_data/            # Data generation
│   ├── generate_mixed_data.py     # Stage-1 SFT data (summary / recall / continue).
│   ├── generate_qa.py             # QA-pair generation for stage-2 SFT and GRPO.
│   ├── loogle_to_parquet.py       # Convert LooGLE jsonl → verl-format parquet.
│   ├── preprocess_squad.py        # SQuAD preprocessing.
│   ├── preprocess_narrativeqa.py  # NarrativeQA preprocessing.
│   ├── preprocess_gsm8k.py        # GSM8K preprocessing.
│   └── mix_data.py                # Concatenate per-passage parquet shards.
│
├── eval_loogle/             # LooGLE evaluation
│   ├── common.py              # Shared loaders (model / tokenizer / data).
│   ├── eval_with_remote.py    # Generate + judge in one process.
│   ├── generate_answers.py    # Generate answers only (decoupled from judging).
│   ├── score_answers.py       # Score generated answers locally.
│   ├── score_answers_with_server.py  # Score against the judge server.
│   └── judge_utils.py
│
├── analysis/                # Skill-vector analysis
│   ├── extract_delta.py     # Extract per-layer/per-tensor delta statistics.
│   ├── heatmap.py           # Visualize the delta as a heatmap.
│   └── save_cosine.py       # Cosine similarity between deltas across runs.
│
├── scripts/
│   ├── gen_data/            # Data-generation entry points.
│   │   ├── gen_sft_data.sh
│   │   ├── gen_rl_data.sh
│   │   └── process_loogle_eval.sh
│   ├── train/               # Training entry points.
│   │   ├── sft.sh             # Stage-1 + Stage-2 SFT for one passage.
│   │   ├── all_sft.sh         # Loop SFT over many passages.
│   │   ├── sft_iter.sh        # Iterative SFT loop.
│   │   ├── grpo.sh            # GRPO on one passage.
│   │   ├── grpo_single.sh     # GRPO with a different schedule.
│   │   ├── grpo_round10.sh    # Multi-round GRPO with PaST inheritance.
│   │   ├── start_judge_server.sh  # Launch the FastAPI judge server.
│   │   └── merge_data.sh
│   ├── eval/
│   │   └── eval_sft.sh
│   ├── eval_gsm8k.sh
│   ├── eval_with_context.sh
│   ├── merge_and_eval.sh    # PaST inheritance + LooGLE evaluation.
│   ├── merge_fsdp_to_hf.sh  # Convert FSDP checkpoints to HF format.
│   ├── run_train_pipeline.sh
│   ├── show_results.py      # Aggregate test_results.json into a markdown table.
│   └── test_merge.sh
│
├── debug/                   # Small inspection / sanity-check scripts.
├── loogle-clean/            # Cleaned LooGLE benchmark (data + reference scripts).
├── SEAL/                    # SQuAD experiment built on the SEAL framework
│                              (paper §5.1.1). See SEAL/README.md.
├── env.yaml                 # Conda environment (`past`).
├── flashattn_requirements.txt
└── README.md

Runtime / output directories (created by the scripts, all gitignored): data/, data-new/, checkpoints/, outputs/, wandb/, eval_results/, results_combined/, viz_output/.


Setup

1. Clone the repo

git clone https://github.com/MuLabPKU/PaST.git
cd PaST

2. Create the environment

conda env create -f env.yaml
conda activate past   # the env name in env.yaml
pip install -r flashattn_requirements.txt

3. Patch verl

Following verl#1296 and this fix, patch verl/trainer/ppo/reward.py so that custom reward functions are loaded under a stable module name derived from an MD5 hash of the file path. The patched imports, helper, and loader:

import multiprocessing
import os
import hashlib
from functools import partial

import ray

from verl import DataProto
from verl.utils.reward_score import default_compute_score


def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs):
    merged_kwargs = {**kwargs, **extra_kwargs}
    return raw_fn(*args, **merged_kwargs)


def get_custom_reward_fn(config):
    import importlib.util
    import sys

    reward_fn_config = config.get("custom_reward_function") or {}
    file_path = reward_fn_config.get("path")
    if not file_path:
        return None

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Reward function file '{file_path}' not found.")

    module_name = f"custom_reward_{hashlib.md5(file_path.encode()).hexdigest()}"

    if module_name in sys.modules:
        module = sys.modules[module_name]
    else:
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
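        try:
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
        except Exception as e:
            # Drop the half-initialised module so a later call retries the import.
            sys.modules.pop(module_name, None)
            raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e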

    function_name = reward_fn_config.get("name")
    if not hasattr(module, function_name):
        raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")

    print(f"using customized reward function '{function_name}' from '{file_path}'")
    raw_fn = getattr(module, function_name)
    reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {}))
    return partial(_call_with_kwargs, raw_fn, reward_kwargs)
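
To see what the path-hashed module name buys you, the sketch below reproduces the loading logic with only the standard library. It is a simplified illustration, not verl's code: `load_fn_from_path`, the temporary reward file, and the `scale` keyword are hypothetical names chosen for the demo, and extra keyword arguments are bound with a plain `partial` rather than the `_call_with_kwargs` helper.

```python
import hashlib
import importlib.util
import sys
import tempfile
from functools import partial
from pathlib import Path


def load_fn_from_path(file_path: str, function_name: str, **extra_kwargs):
    """Load `function_name` from `file_path` under a name derived from the path."""
    module_name = f"custom_reward_{hashlib.md5(file_path.encode()).hexdigest()}"
    if module_name in sys.modules:
        # Same path -> same name -> the already-imported module is reused.
        module = sys.modules[module_name]
    else:
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
    return partial(getattr(module, function_name), **extra_kwargs)


# Hypothetical reward file, written to a temporary directory for the demo.
tmp_dir = tempfile.mkdtemp()
reward_file = Path(tmp_dir) / "my_reward.py"
reward_file.write_text(
    "def compute_score(solution, truth, scale=1.0):\n"
    "    return scale * float(solution.strip() == truth.strip())\n"
)

score_fn = load_fn_from_path(str(reward_file), "compute_score", scale=0.5)
score_fn_again = load_fn_from_path(str(reward_file), "compute_score", scale=0.5)

print(score_fn("Paris ", "Paris"))           # 0.5
print(score_fn.func is score_fn_again.func)  # True: the module was imported once
```

Because the module name depends only on the file path, repeated calls for the same file resolve to the same module object.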

4. Configure secrets

The training scripts and the judge server need API keys. None of these should be committed.

# Weights & Biases (used by training scripts)
export WANDB_API_KEY=...

# OpenAI key for the judge server
export OPENAI_API_KEY=sk-...
# or, equivalently:
echo "sk-..." > openai_api_key.txt   # gitignored
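
A server process could resolve the key along the following lines. This is a minimal sketch, assuming the environment variable takes precedence over the file; `load_openai_key` is a hypothetical helper, and the actual lookup order in reward_server.py may differ.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def load_openai_key(
    key_file: str = "openai_api_key.txt",
    env: Optional[Mapping[str, str]] = None,
) -> str:
    """Return the key from OPENAI_API_KEY if set, otherwise from `key_file`."""
    env = os.environ if env is None else env
    key = env.get("OPENAI_API_KEY", "").strip()
    if key:
        return key
    path = Path(key_file)
    if path.is_file():
        key = path.read_text().strip()
        if key:
            return key
    raise RuntimeError(
        f"No API key found: set OPENAI_API_KEY or create {key_file}."
    )
```

Failing fast with a clear message is usually preferable to discovering a missing key partway through a GRPO run.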

Pipeline overview

For each LooGLE passage p we run:

  1. Stage-1 SFT — train on summary / recall / continue tasks built from the passage. This injects passage-specific knowledge.
  2. Stage-2 SFT — short SFT on QA pairs derived from the passage.
  3. Stage-RL (GRPO) — RL with a judge-model reward on QA prompts sampled from a separate set of source passages.
  4. PaST inheritance — compute the delta between the GRPO model and its pre-RL initialisation, then add it onto target models that have only been through stages 1+2.
  5. Evaluation — generate answers on held-out LooGLE QA and score them with the judge server.
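
Step 4 can be written compactly. The notation here is introduced for illustration and is not taken from the paper: the source model after stages 1+2 is written as θ with superscript SFT and subscript src, the same model after GRPO carries superscript RL, and a target model after stages 1+2 carries subscript tgt.

```latex
\Delta_{\text{skill}} = \theta^{\text{RL}}_{\text{src}} - \theta^{\text{SFT}}_{\text{src}},
\qquad
\theta^{\text{PaST}}_{\text{tgt}} = \theta^{\text{SFT}}_{\text{tgt}} + \Delta_{\text{skill}}
```

The skill vector is computed once on the source passages and reused for every target, so no RL is run on the target passages.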

Step-by-step usage

1. Generate data

Generate stage-1 (summary/recall/continue) and stage-2 (QA) SFT data:

# Args: <CUDA_VISIBLE_DEVICES> <passage range>
bash scripts/gen_data/gen_sft_data.sh 0 0-104

You can shard across GPUs:

bash scripts/gen_data/gen_sft_data.sh 0 0-25
bash scripts/gen_data/gen_sft_data.sh 1 26-50
# ...
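
If you would rather compute the shard boundaries than write them by hand, a small helper can print the commands. This is a sketch, assuming the range argument is inclusive on both ends (as the 0-25 / 26-50 example suggests) and one shard per GPU; `shard_range` is a hypothetical helper, not part of the repository.

```python
from typing import List, Tuple


def shard_range(start: int, end: int, n_shards: int) -> List[Tuple[int, int]]:
    """Split the inclusive range [start, end] into contiguous inclusive shards."""
    total = end - start + 1
    if n_shards < 1 or n_shards > total:
        raise ValueError("n_shards must be between 1 and the number of passages")
    base, extra = divmod(total, n_shards)
    shards = []
    lo = start
    for i in range(n_shards):
        # The first `extra` shards take one additional passage each.
        size = base + (1 if i < extra else 0)
        shards.append((lo, lo + size - 1))
        lo += size
    return shards


for gpu, (lo, hi) in enumerate(shard_range(0, 104, 4)):
    print(f"bash scripts/gen_data/gen_sft_data.sh {gpu} {lo}-{hi}")
```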

Generate the prompt set used for RL:

bash scripts/gen_data/gen_rl_data.sh 0 100-104

Convert LooGLE evaluation data to verl parquet format:

bash scripts/gen_data/process_loogle_eval.sh

2. Two-stage SFT

# Edit `passage_id` at the top of the script first.
bash scripts/train/sft.sh

Or loop over many passages:

bash scripts/train/all_sft.sh

3. Start the judge server

GRPO and evaluation both need it.

# Make sure OPENAI_API_KEY is exported (or in openai_api_key.txt).
bash scripts/train/start_judge_server.sh
# Default port: 8123. Update JUDGE_SERVER_URL in reward_function.py if needed.
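
For orientation, a reward function that delegates scoring to an HTTP judge might look like the sketch below. The `compute_score(data_source, solution_str, ground_truth, extra_info)` signature follows verl's custom reward convention; everything else is an assumption made for illustration: the `/judge` path, the payload field names, the `score` response field, and the injectable `post` argument. The repository's reward_function.py defines its own protocol.

```python
import json
import urllib.request
from typing import Callable, Optional

JUDGE_SERVER_URL = "http://localhost:8123/judge"  # hypothetical endpoint path


def _post_json(url: str, payload: dict, timeout: float = 30.0) -> dict:
    """POST a JSON payload and decode the JSON response."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


def compute_score(
    data_source: str,
    solution_str: str,
    ground_truth: str,
    extra_info: Optional[dict] = None,
    post: Callable[[str, dict], dict] = _post_json,
) -> float:
    """Ask the judge server for a score; return 0.0 if the call fails."""
    payload = {
        "question": (extra_info or {}).get("question", ""),  # hypothetical field
        "prediction": solution_str,
        "reference": ground_truth,
    }
    try:
        reply = post(JUDGE_SERVER_URL, payload)
        return float(reply["score"])
    except Exception:
        # A failed judge call should not crash a long RL run.
        return 0.0
```

Passing the HTTP call in as `post` lets you exercise the function offline, for example with `post=lambda url, payload: {"score": 1}`.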

4. GRPO

# Edit `passage_id` at the top of the script first.
bash scripts/train/grpo.sh

5. Skill transfer (PaST)

After GRPO finishes on a source passage, add the learned skill vector to the models of target passages that have only been through SFT:

python inherit_weight.py \
    --grpo_model   checkpoints/passage<src>/stage2-grpo/global_step_<N>/actor/huggingface \
    --base_model   checkpoints/passage<src>/stage2/global_step_<M> \
    --target_models \
        checkpoints/passage<tgt1>/stage2/global_step_<...> \
        checkpoints/passage<tgt2>/stage2/global_step_<...> \
    --output_dirs \
        checkpoints/passage<tgt1>/inherited \
        checkpoints/passage<tgt2>/inherited

scripts/merge_and_eval.sh and scripts/train/grpo_round10.sh show batched and iterative versions of this loop.

6. Evaluation

Single-process generate + judge:

CUDA_VISIBLE_DEVICES=0 python eval_loogle/eval_with_remote.py \
    --passage_idx 1 \
    --model_path  checkpoints/passage1/inherited \
    --n_runs      3

Decoupled (generate now, judge later):

python eval_loogle/generate_answers.py \
    --passage_idx 1 \
    --model_path  checkpoints/passage1/inherited \
    --output_file outputs/loogle/passage1.json

python eval_loogle/score_answers_with_server.py \
    --input_file outputs/loogle/passage1.json

Aggregate all test_results.json files into a markdown table:

python scripts/show_results.py
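
The aggregation step amounts to walking the output tree and formatting one row per run. The sketch below illustrates the idea under a loudly stated assumption: that each test_results.json is a JSON object with a numeric `accuracy` field. That schema, and the `results_table` helper, are hypothetical; check scripts/show_results.py for the fields it actually reads.

```python
import json
from pathlib import Path
from typing import List


def results_table(root: str, metric: str = "accuracy") -> str:
    """Collect `metric` from every test_results.json under `root` into a markdown table."""
    rows: List[str] = []
    for path in sorted(Path(root).rglob("test_results.json")):
        with path.open() as f:
            data = json.load(f)
        if metric in data:
            # Name each run after its directory, relative to the search root.
            run_name = path.parent.relative_to(root).as_posix()
            rows.append(f"| {run_name} | {float(data[metric]):.2f} |")
    header = [f"| run | {metric} |", "| --- | --- |"]
    return "\n".join(header + rows)
```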

The PaST operator

inherit_weight.py implements the entire method in under 100 lines:

deltas[name] = grpo_model[name] - base_model[name]   # skill vector
target_model[name] += deltas[name]                   # skill injection

Both steps iterate over named_parameters() and run on CPU. The script loads the two source models once, then applies the deltas to as many targets as you pass on the command line. calc_delta_and_save.py saves the same deltas to a single .pth file for inspection.
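
The two lines above can be exercised end to end on toy data. The following is a dependency-free sketch rather than the code in inherit_weight.py: plain Python lists stand in for tensors, and `extract_skill_vector`, `inject_skill_vector`, and the passage names are hypothetical. With real checkpoints the same loop would run over the models' named parameters.

```python
from typing import Dict, List

StateDict = Dict[str, List[float]]  # stand-in for a name -> tensor mapping


def extract_skill_vector(rl: StateDict, base: StateDict) -> StateDict:
    """Per-parameter difference between the RL model and its pre-RL initialisation."""
    if rl.keys() != base.keys():
        raise KeyError("RL and base models must share parameter names")
    return {name: [r - b for r, b in zip(rl[name], base[name])] for name in rl}


def inject_skill_vector(target: StateDict, deltas: StateDict) -> StateDict:
    """Return a new state dict with the skill vector added to the target."""
    if target.keys() != deltas.keys():
        raise KeyError("target model and skill vector must share parameter names")
    return {
        name: [t + d for t, d in zip(target[name], deltas[name])]
        for name in target
    }


base = {"w": [1.0, 2.0], "b": [0.0]}   # source model after SFT
rl = {"w": [1.5, 1.0], "b": [0.25]}    # same model after GRPO
targets = {
    "passage_a": {"w": [3.0, 3.0], "b": [1.0]},
    "passage_b": {"w": [0.0, 4.0], "b": [-1.0]},
}

deltas = extract_skill_vector(rl, base)  # computed once, reused for every target
inherited = {name: inject_skill_vector(sd, deltas) for name, sd in targets.items()}

print(deltas)                  # {'w': [0.5, -1.0], 'b': [0.25]}
print(inherited["passage_a"])  # {'w': [3.5, 2.0], 'b': [1.25]}
```

Checking that parameter names match before subtracting guards against mixing checkpoints from different architectures.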


Citation

@article{tang2026knowledge,
  title={Knowledge is Not Enough: Injecting RL Skills for Continual Adaptation},
  author={Tang, Pingzhi and Wang, Yiding and Zhang, Muhan},
  journal={arXiv preprint arXiv:2601.11258},
  year={2026}
}
