diff --git a/train_mamba.py b/train_mamba.py index ab39bb3..8485d26 100644 --- a/train_mamba.py +++ b/train_mamba.py @@ -35,7 +35,7 @@ def run(args): per_device_train_batch_size=args.batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, optim=args.optim, - output_dir="mamba-chat", + output_dir=args.output_dir, logging_steps=50, save_steps=500, ), @@ -44,6 +44,7 @@ def run(args): trainer.train() + trainer.save_model(output_dir=f"{args.output_dir}/complete") if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -55,6 +56,7 @@ def run(args): parser.add_argument("--optim", type=str, default="adamw_torch") parser.add_argument("--data_path", type=str, default="./data/ultrachat_small.jsonl") parser.add_argument("--num_epochs", type=int, default=1) + parser.add_argument("--output_dir", type=str, default="mamba-chat") args = parser.parse_args() run(args) diff --git a/trainer/mamba_trainer.py b/trainer/mamba_trainer.py index 77708ef..766add8 100644 --- a/trainer/mamba_trainer.py +++ b/trainer/mamba_trainer.py @@ -17,7 +17,7 @@ def compute_loss(self, model, inputs, return_outputs=False): return lm_loss - def save_model(self, output_dir, _internal_call): + def save_model(self, output_dir, _internal_call: bool = False): if not os.path.exists(output_dir): os.makedirs(output_dir)