
feat: Add decoupled weight decay (AdamW) to Adam optimizer#1037

Open
AliAlimohammadi wants to merge 1 commit into TheAlgorithms:master from AliAlimohammadi:add-adamw-weight-decay

Conversation

@AliAlimohammadi
Contributor

Description

Extends the existing Adam optimizer in src/machine_learning/optimization/adam.rs to support decoupled weight decay (AdamW), as introduced in Decoupled Weight Decay Regularization (Loshchilov & Hutter, 2019).

Rather than adding a separate AdamW struct, a single weight_decay: f64 field (defaulting to 0.0) is added to the existing Adam struct. When weight_decay is 0.0 the update is identical to standard Adam. When positive, the decay term $\lambda \cdot \theta_{t-1}$ is subtracted directly from the parameters after the adaptive gradient step, keeping it independent of the second moment scaling — the key correction AdamW makes over naive L2 regularisation inside Adam.

The step signature changes from step(&mut self, gradients: &[f64]) to step(&mut self, gradients: &[f64], params: &[f64]), since the decoupled decay term requires the current parameter values. All existing tests are updated accordingly (zero-initialised params preserve the original expected values).

Algorithm

Both variants share the same moment update rules:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$

$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

They differ only in the parameter update step:

Adam — weight decay is absent (or equivalently, folded into $g_t$ as L2 regularisation, where it gets scaled down by $1/\sqrt{\hat{v}_t}$):

$$\theta_t = \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

AdamW — weight decay is applied directly to $\theta_{t-1}$, fully decoupled from the adaptive scaling so its effect is constant regardless of gradient history:

$$\theta_t = \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} - \alpha \lambda \theta_{t-1}$$

Setting $\lambda = 0$ in the AdamW update recovers Adam exactly, both mathematically and in the implementation (verified by test_adamw_step_weight_decay_zero_matches_adam).
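The combined update above can be sketched as follows. This is a minimal illustration, not the actual adam.rs code from the PR; the field names, constructor, and return type are assumptions made for the example:

```rust
// Minimal sketch of the combined Adam/AdamW step. Names and defaults are
// illustrative assumptions, not the real adam.rs implementation.
struct Adam {
    lr: f64,           // alpha
    beta1: f64,
    beta2: f64,
    epsilon: f64,
    weight_decay: f64, // lambda; 0.0 recovers plain Adam
    m: Vec<f64>,       // first moment estimates
    v: Vec<f64>,       // second moment estimates
    t: u32,            // timestep
}

impl Adam {
    fn new(n: usize, lr: f64, weight_decay: f64) -> Self {
        Adam {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay,
            m: vec![0.0; n],
            v: vec![0.0; n],
            t: 0,
        }
    }

    /// Returns theta_t given the gradients g_t and the previous
    /// parameters theta_{t-1}.
    fn step(&mut self, gradients: &[f64], params: &[f64]) -> Vec<f64> {
        self.t += 1;
        let bc1 = 1.0 - self.beta1.powi(self.t as i32); // 1 - beta1^t
        let bc2 = 1.0 - self.beta2.powi(self.t as i32); // 1 - beta2^t
        gradients
            .iter()
            .zip(params)
            .enumerate()
            .map(|(i, (&g, &p))| {
                self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
                self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
                let m_hat = self.m[i] / bc1;
                let v_hat = self.v[i] / bc2;
                // Adaptive step, then decoupled decay applied to theta_{t-1}:
                // the decay term never passes through 1/sqrt(v_hat).
                p - self.lr * m_hat / (v_hat.sqrt() + self.epsilon)
                    - self.lr * self.weight_decay * p
            })
            .collect()
    }
}
```

With `weight_decay = 0.0` the last term vanishes and the result is the plain Adam update; likewise, passing all-zero `params` makes the decay term zero regardless of `weight_decay`, which is why the PR's zero-initialised test parameters preserve the original expected values.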

Type of change

  • Breaking change (fix or feature that would cause existing functionality to not work as expected)

Adam::step gains a required params: &[f64] argument. Existing callers must pass the current parameter slice; passing &vec![0.0; n] restores the original behaviour exactly.
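A caller migration might look like the following. The optimizer here is a simplified stand-in (a bare SGD-style step) used only to show the new call shape, not the real Adam type from adam.rs:

```rust
// Simplified stand-in optimizer illustrating the new two-argument `step`
// signature; the real Adam in adam.rs carries moment estimates, bias
// correction, and a timestep as well.
struct Adam {
    lr: f64,
    weight_decay: f64,
}

impl Adam {
    // New signature: the current parameters must be passed in so the
    // decoupled decay term lr * lambda * theta_{t-1} can be formed.
    fn step(&mut self, gradients: &[f64], params: &[f64]) -> Vec<f64> {
        gradients
            .iter()
            .zip(params)
            .map(|(&g, &p)| p - self.lr * g - self.lr * self.weight_decay * p)
            .collect()
    }
}
```

Existing callers that previously wrote `optimizer.step(&gradients)` now pass the parameter slice as well; passing `&vec![0.0; n]` zeroes the decay term and reproduces the pre-change behaviour.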

Checklist

  • I ran cargo clippy --all -- -D warnings just before my last commit and fixed any issue that was found.
  • I ran cargo fmt just before my last commit.
  • I ran cargo test just before my last commit and all tests passed.
  • I added my algorithm to the corresponding mod.rs file within its own folder, and in any parent folder(s).
  • I added my algorithm to DIRECTORY.md with the correct link.
  • I checked CONTRIBUTING.md and my code follows its guidelines.

@AliAlimohammadi
Contributor Author

@siriak, this is ready to be merged.

@codecov-commenter

Codecov Report

❌ Patch coverage is 73.83178% with 28 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.94%. Comparing base (e08f5a5) to head (96a8444).

Files with missing lines Patch % Lines
src/machine_learning/optimization/adam.rs 73.83% 28 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1037      +/-   ##
==========================================
- Coverage   96.01%   95.94%   -0.07%     
==========================================
  Files         392      392              
  Lines       29722    29806      +84     
==========================================
+ Hits        28537    28597      +60     
- Misses       1185     1209      +24     
