diff --git a/README.md b/README.md index b8d4767..571e666 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ A unified, extensible framework for text classification with categorical variabl - **Unified yet highly customizable**: - Use any tokenizer from HuggingFace or the original fastText's ngram tokenizer. - Text embedding is split into two composable stages: **`TokenEmbedder`** (token → per-token vectors, with optional self-attention) and **`SentenceEmbedder`** (aggregation: mean / first / last / label attention). Combine them with `CategoricalVariableNet` and `ClassificationHead` — all are `torch.nn.Module`. - - The `TextClassificationModel` class assembles these components and can be extended for custom behavior. + - **Two architecture paths**: use `ModelConfig` + the `torchTextClassifiers` constructor for the standard `TextClassificationModel` (zero boilerplate), or build any `nn.Module` you like and pass it to `torchTextClassifiers.from_model()` for full control. The `contrib` sub-package ships ready-made custom architectures (e.g. `MultiLevelTextClassificationModel` for multi-task classification) as reference implementations. - **Multiclass / multilabel classification support**: Support for both multiclass (only one label is true) and multi-label (several labels can be true) classification tasks. - **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging - **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code: @@ -56,6 +56,7 @@ See the [examples/](examples/) directory for: - Mixed features (text + categorical) - Advanced training configurations - Prediction and explainability +- [Multi-level classification](examples/multilevel_example.py) — custom architecture via `from_model` and `contrib` ## šŸ“„ License diff --git a/docs/source/architecture/overview.md b/docs/source/architecture/overview.md index 737fa28..e0b5e84 100644 --- a/docs/source/architecture/overview.md +++ b/docs/source/architecture/overview.md @@ -599,6 +599,98 @@ predictions = classifier.predict(new_texts) - Don't need custom architecture - Want simplicity over control +## Two Architecture Paths + +torchTextClassifiers offers two ways to build a classifier, covering different +levels of customisation: + +### Path 1 — Standard architecture (ModelConfig) + +The `torchTextClassifiers` constructor accepts a `ModelConfig` and builds a +`TextClassificationModel` for you automatically. This covers the vast majority +of use cases: single-task binary, multi-class, or multi-label classification, +with or without categorical variables, with or without self-attention and label +attention. + +```python +from torchTextClassifiers import torchTextClassifiers, ModelConfig + +classifier = torchTextClassifiers( + tokenizer=tokenizer, + model_config=ModelConfig(embedding_dim=128, num_classes=5), +) +classifier.train(texts, labels, training_config) +predictions = classifier.predict(new_texts) +``` + +You never instantiate `TextClassificationModel` directly; `ModelConfig` is the +only knob you need. + +### Path 2 — Custom architecture (from_model) + +When `TextClassificationModel` cannot express what you need — multiple +classification heads, shared encoders across tasks, or any other topology — +build your own `nn.Module` and wrap it with `torchTextClassifiers.from_model`. +The wrapper then provides the same `predict` / `save` / `load` interface around +your model. 
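+The sketch below shows the minimal shape such a model needs — the class name
+and body are illustrative; only the attributes and the `forward` signature are
+part of the contract: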
+ +```python +import torch.nn as nn +from torchTextClassifiers import torchTextClassifiers + +class MyModel(nn.Module): + num_classes = 3 + categorical_variable_net = None # or a CategoricalVariableNet instance + + def forward(self, input_ids, attention_mask, categorical_vars=None, **kwargs): + ... + return logits # (batch, num_classes) — raw logits, not softmaxed + +classifier = torchTextClassifiers.from_model( + tokenizer=tokenizer, + pytorch_model=MyModel(), +) +``` + +**Required interface for custom models:** + +| Requirement | Details | +|---|---| +| `forward(input_ids, attention_mask, categorical_vars=None, **kwargs)` | Exact positional names; extra kwargs are ignored | +| Returns raw logits | `torch.Tensor` of shape `(batch, num_classes)`, or `list[torch.Tensor]` for multi-task | +| `num_classes` attribute | `int` for single-task; `list[int]` for multi-task | +| `categorical_variable_net` attribute | A `CategoricalVariableNet` instance, or `None` | + +### contrib — reference custom architectures + +The `torchTextClassifiers.contrib` sub-package ships example architectures that +follow the `from_model` interface and can be used directly or as starting points: + +| Class | Purpose | +|---|---| +| `MultiLevelTextClassificationModel` | Multi-task classifier: one shared `TokenEmbedder`, one `SentenceEmbedder` + `ClassificationHead` per task | +| `MultiLevelCrossEntropyLoss` | Weighted cross-entropy averaged across tasks | + +```python +from torchTextClassifiers.contrib import ( + MultiLevelTextClassificationModel, + MultiLevelCrossEntropyLoss, +) + +model = MultiLevelTextClassificationModel( + token_embedder=token_embedder, + sentence_embedders=[se_level1, se_level2, se_level3], + classification_heads=[head1, head2, head3], + categorical_variable_net=cat_net, +) +classifier = torchTextClassifiers.from_model(tokenizer=tokenizer, pytorch_model=model) +``` + +See [examples/multilevel_example.py](https://github.com/InseeFrLab/torchTextClassifiers/blob/main/examples/multilevel_example.py) +for a complete working script. + +--- + ## For Advanced Users ### Direct PyTorch Usage diff --git a/docs/source/tutorials/custom_model.md b/docs/source/tutorials/custom_model.md new file mode 100644 index 0000000..c8a505a --- /dev/null +++ b/docs/source/tutorials/custom_model.md @@ -0,0 +1,215 @@ +# Custom Architectures with from_model + +**Difficulty:** Advanced | **Time:** 30 minutes + +## When to use this + +The standard `torchTextClassifiers` constructor + `ModelConfig` covers most +single-task classification needs. Use `from_model` when you need something the +standard architecture cannot express: + +- **Multiple classification heads** (multi-task / hierarchical labels) +- **Shared encoders** across several outputs +- **Custom combination logic** between text and categorical embeddings +- **Any other topology** that does not fit a single linear pipeline + +--- + +## Required interface + +Your custom model must satisfy three contracts so that the wrapper's `predict`, +`save`, and `load` methods work correctly. + +### 1. `forward` signature + +```python +def forward( + self, + input_ids: torch.Tensor, # (batch, seq_len) — Long + attention_mask: torch.Tensor, # (batch, seq_len) — int + categorical_vars: torch.Tensor, # (batch, n_cats) — Long, may be None + **kwargs, # ignored by the wrapper +) -> torch.Tensor | list[torch.Tensor]: + ... +``` + +- The argument **names must match exactly** — the wrapper calls the model with + keyword arguments from the dataloader collate function. 
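+  In practice the wrapper invokes
+  `model(input_ids=..., attention_mask=..., categorical_vars=...)`, so renaming
+  any of these parameters breaks the call.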
+- The return value must be **raw logits** (not softmaxed). + - Single task → `torch.Tensor` of shape `(batch, num_classes)` + - Multi-task → `list[torch.Tensor]`, one tensor per task + +### 2. `num_classes` attribute + +```python +model.num_classes # int (single task) +model.num_classes # list[int] (multi-task — one entry per task head) +``` + +### 3. `categorical_variable_net` attribute + +```python +model.categorical_variable_net # CategoricalVariableNet | None +``` + +Set this to `None` if your model does not use categorical features. When it is +not `None` the wrapper reads +`categorical_variable_net.categorical_vocabulary_sizes` to configure data +encoding. + +--- + +## Minimal example — single-task custom model + +```python +import torch +import torch.nn as nn +from torchTextClassifiers import torchTextClassifiers +from torchTextClassifiers.model.components import TokenEmbedder, TokenEmbedderConfig +from torchTextClassifiers.tokenizers import WordPieceTokenizer + +class MyClassifier(nn.Module): + def __init__(self, vocab_size: int, num_classes: int): + super().__init__() + self.token_embedder = TokenEmbedder(TokenEmbedderConfig( + vocab_size=vocab_size, embedding_dim=64, padding_idx=0, + )) + self.pool = lambda x, mask: (x * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True) + self.head = nn.Linear(64, num_classes) + + # Required attributes + self.num_classes = num_classes + self.categorical_variable_net = None # no categorical features + + def forward(self, input_ids, attention_mask, categorical_vars=None, **kwargs): + out = self.token_embedder(input_ids, attention_mask) + sentence = self.pool(out["token_embeddings"], attention_mask.float()) + return self.head(sentence) # (batch, num_classes) — raw logits + +tokenizer = WordPieceTokenizer(vocab_size=5000) +tokenizer.train(texts) + +model = MyClassifier(vocab_size=tokenizer.vocab_size, num_classes=3) + +classifier = torchTextClassifiers.from_model( + tokenizer=tokenizer, + pytorch_model=model, +) +classifier.train(texts, labels, training_config) +predictions = classifier.predict(new_texts) +``` + +--- + +## Multi-task example — contrib architecture + +For multi-task classification the `contrib` sub-package provides ready-made +classes that follow the interface above. + +```python +from torchTextClassifiers import torchTextClassifiers, TrainingConfig +from torchTextClassifiers.contrib import ( + MultiLevelTextClassificationModel, + MultiLevelCrossEntropyLoss, +) +from torchTextClassifiers.model.components import ( + CategoricalVariableNet, + ClassificationHead, + LabelAttentionConfig, + SentenceEmbedder, SentenceEmbedderConfig, + TokenEmbedder, TokenEmbedderConfig, +) +from torchTextClassifiers.value_encoder import ValueEncoder + +# Assume tokenizer, value_encoder, and model_config are already built. +# value_encoder.num_classes is a list[int] — one count per task level. 
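+# As a reference sketch (mirroring examples/multilevel_example.py in this repo),
+# the encoder can be assembled from one fitted label encoder per level, e.g.:
+#   label_encoders = [LabelEncoder().fit(y_level) for y_level in y_levels]  # y_levels is illustrative
+#   value_encoder = ValueEncoder(label_encoders, {"category": cat_enc, "color": color_enc})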
+ +token_embedder = TokenEmbedder(TokenEmbedderConfig( + vocab_size=tokenizer.vocab_size, + embedding_dim=64, + padding_idx=tokenizer.padding_idx, +)) +cat_net = CategoricalVariableNet( + categorical_vocabulary_sizes=value_encoder.vocabulary_sizes, + categorical_embedding_dims=8, + text_embedding_dim=64, +) + +sentence_embedders = [] +classification_heads = [] +for n_cls in value_encoder.num_classes: + sentence_embedders.append(SentenceEmbedder(SentenceEmbedderConfig( + aggregation_method=None, + label_attention_config=LabelAttentionConfig(n_head=2, num_classes=n_cls, embedding_dim=64), + ))) + classification_heads.append(ClassificationHead(input_dim=64 + cat_net.output_dim, num_classes=1)) + +model = MultiLevelTextClassificationModel( + token_embedder=token_embedder, + sentence_embedders=sentence_embedders, + classification_heads=classification_heads, + categorical_variable_net=cat_net, +) + +classifier = torchTextClassifiers.from_model( + tokenizer=tokenizer, + pytorch_model=model, + value_encoder=value_encoder, +) + +training_config = TrainingConfig( + num_epochs=10, + batch_size=32, + lr=1e-3, + raw_categorical_inputs=True, + loss=MultiLevelCrossEntropyLoss(num_classes=list(value_encoder.num_classes)), +) +classifier.train(X_train, y_train, training_config) +predictions = classifier.predict(X_test) +``` + +`predictions` is a dict with one key per task level. + +See [examples/multilevel_example.py](https://github.com/InseeFrLab/torchTextClassifiers/blob/main/examples/multilevel_example.py) +for the full runnable script. + +--- + +## contrib reference + +| Class | Description | +|---|---| +| `MultiLevelTextClassificationModel` | Shared `TokenEmbedder`, one `SentenceEmbedder` + `ClassificationHead` per task | +| `MultiLevelCrossEntropyLoss` | Per-task cross-entropy, optionally weighted by `num_classes` | + +```python +from torchTextClassifiers.contrib import ( + MultiLevelTextClassificationModel, + MultiLevelCrossEntropyLoss, +) +``` + +These classes are reference implementations — use them directly or as a +starting point for your own architecture. + +--- + +## Saving and loading + +`save` and `load` work the same way regardless of which path was used. Custom +models are serialised as a pickle of the model structure plus a separate +state-dict file; the `_custom_model` flag in the checkpoint tells `load` which +strategy to use. + +```python +classifier.save("my_classifier/") +loaded = torchTextClassifiers.load("my_classifier/") +``` + +--- + +## Next steps + +- **Architecture overview**: {doc}`../architecture/overview` — component reference and design philosophy +- **API reference**: {doc}`../api/wrapper` — full `torchTextClassifiers` API +- **contrib source**: `torchTextClassifiers/contrib/multilevel.py` diff --git a/docs/source/tutorials/index.md b/docs/source/tutorials/index.md index 845c221..160f53c 100644 --- a/docs/source/tutorials/index.md +++ b/docs/source/tutorials/index.md @@ -10,6 +10,7 @@ multiclass_classification mixed_features explainability multilabel_classification +custom_model ``` ## Overview @@ -116,6 +117,21 @@ Assign multiple labels to each text sample for complex classification scenarios. **Difficulty:** Advanced | **Time:** 30 minutes ::: +:::{grid-item-card} {fas}`puzzle-piece` Custom Architectures +:link: custom_model +:link-type: doc + +Plug any PyTorch model into the torchTextClassifiers wrapper via `from_model`. 
+ +**What you'll learn:** +- When to go beyond `TextClassificationModel` +- The required `forward` / `num_classes` / `categorical_variable_net` interface +- Using `contrib` classes as reference implementations +- Multi-task classification with `MultiLevelTextClassificationModel` + +**Difficulty:** Advanced | **Time:** 30 minutes +::: + :::: ## Learning Path @@ -130,6 +146,8 @@ graph LR C --> F[Multilabel Classification] D --> E[Explainability] F --> E + D --> G[Custom Architectures] + F --> G style A fill:#e3f2fd style B fill:#bbdefb @@ -137,6 +155,7 @@ graph LR style D fill:#64b5f6 style E fill:#1976d2 style F fill:#42a5f5 + style G fill:#0d47a1 ``` 1. **Start with**: {doc}`../getting_started/quickstart` - Get familiar with the basics diff --git a/examples/advanced_training.py b/examples/advanced_training.py index 03c48a7..82b5364 100644 --- a/examples/advanced_training.py +++ b/examples/advanced_training.py @@ -7,7 +7,6 @@ """ import os -import random import warnings import numpy as np @@ -17,13 +16,14 @@ from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers from torchTextClassifiers.tokenizers import WordPieceTokenizer + def main(): # Set seed for reproducibility SEED = 42 # Set environment variables for full reproducibility - os.environ['PYTHONHASHSEED'] = str(SEED) - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + os.environ["PYTHONHASHSEED"] = str(SEED) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Use PyTorch Lightning's seed_everything for comprehensive seeding seed_everything(SEED, workers=True) @@ -35,10 +35,7 @@ def main(): # Suppress PyTorch Lightning warnings for cleaner output warnings.filterwarnings( - 'ignore', - message='.*', - category=UserWarning, - module='pytorch_lightning' + "ignore", message=".*", category=UserWarning, module="pytorch_lightning" ) print("āš™ļø Advanced Training Configuration Example") @@ -46,7 +43,7 @@ def main(): # Create a larger dataset for demonstrating advanced training print("šŸ“ Creating training dataset...") - + # Generate more diverse training data positive_samples = [ "Excellent product with outstanding quality and performance.", @@ -58,9 +55,9 @@ def main(): "Fantastic features and user-friendly interface design provided.", "Outstanding durability and reliability in daily usage.", "Impressive performance and excellent build quality throughout.", - "Wonderful experience from purchase to delivery service." + "Wonderful experience from purchase to delivery service.", ] - + negative_samples = [ "Terrible product quality, completely disappointed with purchase.", "Poor customer service and slow delivery times experienced.", @@ -71,30 +68,34 @@ def main(): "Horrible user experience and confusing interface design.", "Broke after few days of normal usage.", "Poor value for money, better alternatives available.", - "Disappointing performance and unreliable functionality shown." + "Disappointing performance and unreliable functionality shown.", ] - + # Combine and create arrays X_train = np.array(positive_samples + negative_samples) y_train = np.array([1] * len(positive_samples) + [0] * len(negative_samples)) - + # Validation data - X_val = np.array([ - "Good product with decent quality for the price.", - "Not satisfied with the purchase, poor quality.", - "Excellent service and great product quality.", - "Disappointed with the product performance results." 
- ]) + X_val = np.array( + [ + "Good product with decent quality for the price.", + "Not satisfied with the purchase, poor quality.", + "Excellent service and great product quality.", + "Disappointed with the product performance results.", + ] + ) y_val = np.array([1, 0, 1, 0]) - + # Test data - X_test = np.array([ - "Outstanding product with amazing features!", - "Terrible quality, complete waste of money.", - "Great value and excellent customer support." - ]) + X_test = np.array( + [ + "Outstanding product with amazing features!", + "Terrible quality, complete waste of money.", + "Great value and excellent customer support.", + ] + ) y_test = np.array([1, 0, 1]) - + print(f"Training samples: {len(X_train)}") print(f"Validation samples: {len(X_val)}") print(f"Test samples: {len(X_test)}") @@ -109,15 +110,9 @@ def main(): # Example 1: Basic training with default settings print("\nšŸŽÆ Example 1: Basic training with default settings...") - model_config = ModelConfig( - embedding_dim=100, - num_classes=2 - ) + model_config = ModelConfig(embedding_dim=100, num_classes=2) - classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=model_config - ) + classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config) print("āœ… Classifier created successfully!") training_config = TrainingConfig( @@ -126,45 +121,38 @@ def main(): lr=1e-3, patience_early_stopping=5, num_workers=0, - trainer_params={'deterministic': True} + trainer_params={"deterministic": True}, ) classifier.train( - X_train, y_train, - training_config=training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_train, y_train, training_config=training_config, X_val=X_val, y_val=y_val, verbose=True ) result = classifier.predict(X_test) basic_predictions = result["prediction"].squeeze().numpy() basic_accuracy = (basic_predictions == y_test).mean() print(f"āœ… Basic training completed! 
Accuracy: {basic_accuracy:.3f}") - + # Example 2: Advanced training with custom Lightning trainer parameters print("\nšŸš€ Example 2: Advanced training with custom parameters...") # Create a new classifier for comparison - advanced_model_config = ModelConfig( - embedding_dim=100, - num_classes=2 - ) + advanced_model_config = ModelConfig(embedding_dim=100, num_classes=2) advanced_classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=advanced_model_config + tokenizer=tokenizer, model_config=advanced_model_config ) print("āœ… Advanced classifier created successfully!") # Custom trainer parameters for advanced features advanced_trainer_params = { - 'accelerator': 'auto', # Use GPU if available, else CPU - 'precision': 32, # Use 32-bit precision - 'gradient_clip_val': 1.0, # Gradient clipping - 'accumulate_grad_batches': 2, # Gradient accumulation - 'deterministic': True, # For reproducible results - 'enable_progress_bar': True, # Show progress bar - 'log_every_n_steps': 5, # Log every 5 steps + "accelerator": "auto", # Use GPU if available, else CPU + "precision": 32, # Use 32-bit precision + "gradient_clip_val": 1.0, # Gradient clipping + "accumulate_grad_batches": 2, # Gradient accumulation + "deterministic": True, # For reproducible results + "enable_progress_bar": True, # Show progress bar + "log_every_n_steps": 5, # Log every 5 steps } advanced_training_config = TrainingConfig( @@ -173,33 +161,32 @@ def main(): lr=1e-3, patience_early_stopping=7, num_workers=0, - trainer_params=advanced_trainer_params + trainer_params=advanced_trainer_params, ) advanced_classifier.train( - X_train, y_train, + X_train, + y_train, training_config=advanced_training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_val=X_val, + y_val=y_val, + verbose=True, ) advanced_result = advanced_classifier.predict(X_test) advanced_predictions = advanced_result["prediction"].squeeze().numpy() advanced_accuracy = (advanced_predictions == y_test).mean() print(f"āœ… Advanced training completed! Accuracy: {advanced_accuracy:.3f}") - + # Example 3: Training with CPU-only (useful for small datasets or debugging) print("\nšŸ’» Example 3: CPU-only training...") cpu_model_config = ModelConfig( embedding_dim=64, # Smaller embedding for faster CPU training - num_classes=2 + num_classes=2, ) - cpu_classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=cpu_model_config - ) + cpu_classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=cpu_model_config) print("āœ… CPU classifier created successfully!") cpu_training_config = TrainingConfig( @@ -208,44 +195,40 @@ def main(): lr=1e-3, patience_early_stopping=3, num_workers=0, # No multiprocessing for CPU - trainer_params={'deterministic': True, 'accelerator': 'cpu'} + trainer_params={"deterministic": True, "accelerator": "cpu"}, ) cpu_classifier.train( - X_train, y_train, + X_train, + y_train, training_config=cpu_training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_val=X_val, + y_val=y_val, + verbose=True, ) cpu_result = cpu_classifier.predict(X_test) cpu_predictions = cpu_result["prediction"].squeeze().numpy() cpu_accuracy = (cpu_predictions == y_test).mean() print(f"āœ… CPU training completed! 
Accuracy: {cpu_accuracy:.3f}") - + # Example 4: Custom training with specific Lightning callbacks print("\nšŸ”§ Example 4: Training with custom callbacks...") - custom_model_config = ModelConfig( - embedding_dim=128, - num_classes=2 - ) + custom_model_config = ModelConfig(embedding_dim=128, num_classes=2) - custom_classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=custom_model_config - ) + custom_classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=custom_model_config) print("āœ… Custom classifier created successfully!") # Custom trainer with specific monitoring and checkpointing custom_trainer_params = { - 'max_epochs': 25, - 'enable_progress_bar': True, - 'log_every_n_steps': 1, - 'check_val_every_n_epoch': 2, # Validate every 2 epochs - 'enable_checkpointing': True, - 'enable_model_summary': True, - 'deterministic': True, + "max_epochs": 25, + "enable_progress_bar": True, + "log_every_n_steps": 1, + "check_val_every_n_epoch": 2, # Validate every 2 epochs + "enable_checkpointing": True, + "enable_model_summary": True, + "deterministic": True, } custom_training_config = TrainingConfig( @@ -254,21 +237,23 @@ def main(): lr=1e-3, patience_early_stopping=8, num_workers=0, - trainer_params=custom_trainer_params + trainer_params=custom_trainer_params, ) custom_classifier.train( - X_train, y_train, + X_train, + y_train, training_config=custom_training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_val=X_val, + y_val=y_val, + verbose=True, ) custom_result = custom_classifier.predict(X_test) custom_predictions = custom_result["prediction"].squeeze().numpy() custom_accuracy = (custom_predictions == y_test).mean() print(f"āœ… Custom training completed! Accuracy: {custom_accuracy:.3f}") - + # Compare all training approaches print("\nšŸ“Š Training Comparison Results:") print("-" * 50) @@ -276,35 +261,35 @@ def main(): print(f"Advanced training: {advanced_accuracy:.3f}") print(f"CPU-only training: {cpu_accuracy:.3f}") print(f"Custom training: {custom_accuracy:.3f}") - + # Find best performing model results = { - 'Basic': basic_accuracy, - 'Advanced': advanced_accuracy, - 'CPU-only': cpu_accuracy, - 'Custom': custom_accuracy + "Basic": basic_accuracy, + "Advanced": advanced_accuracy, + "CPU-only": cpu_accuracy, + "Custom": custom_accuracy, } best_method = max(results, key=results.get) print(f"\nšŸ† Best performing method: {best_method} (Accuracy: {results[best_method]:.3f})") - + # Demonstrate prediction with best model print(f"\nšŸ”® Making predictions with {best_method.lower()} model...") best_classifier = { - 'Basic': classifier, - 'Advanced': advanced_classifier, - 'CPU-only': cpu_classifier, - 'Custom': custom_classifier + "Basic": classifier, + "Advanced": advanced_classifier, + "CPU-only": cpu_classifier, + "Custom": custom_classifier, }[best_method] - + predictions = best_classifier.predict(X_test) print("Test predictions:") for i, (text, pred, true) in enumerate(zip(X_test, predictions, y_test)): sentiment = "Positive" if pred == 1 else "Negative" correct = "āœ…" if pred == true else "āŒ" - print(f"{i+1}. {correct} {sentiment}: {text[:50]}...") + print(f"{i + 1}. 
{correct} {sentiment}: {text[:50]}...") print("\nšŸŽ‰ Advanced training example completed successfully!") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/basic_classification.py b/examples/basic_classification.py index 695115d..9da2362 100644 --- a/examples/basic_classification.py +++ b/examples/basic_classification.py @@ -5,12 +5,10 @@ text classification using the Wrapper. """ -import os -import random import warnings import numpy as np -import torch + from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers from torchTextClassifiers.tokenizers import WordPieceTokenizer @@ -18,10 +16,7 @@ def main(): # Suppress PyTorch Lightning batch_size inference warnings for cleaner output warnings.filterwarnings( - 'ignore', - message='.*', - category=UserWarning, - module='pytorch_lightning' + "ignore", message=".*", category=UserWarning, module="pytorch_lightning" ) print("šŸš€ Basic Text Classification Example") @@ -29,77 +24,85 @@ def main(): # Create sample data print("šŸ“ Creating sample data...") - X_train = np.array([ - "I love this product! It's amazing and works perfectly.", - "This is terrible. Worst purchase ever made.", - "Great quality and fast shipping. Highly recommend!", - "Poor quality, broke after one day. Very disappointed.", - "Excellent customer service and great value for money.", - "Overpriced and doesn't work as advertised.", - "Perfect! Exactly what I was looking for.", - "Waste of money. Should have read reviews first.", - "Outstanding product with excellent build quality.", - "Cheap plastic, feels like it will break soon.", - "Absolutely fantastic! Exceeded all my expectations.", - "Horrible experience. Customer service was rude and unhelpful.", - "Best purchase I've made this year. Five stars!", - "Defective item arrived. Packaging was also damaged.", - "Super impressed with the performance and durability.", - "Total disappointment. Doesn't match the description at all.", - "Wonderful product! My whole family loves it.", - "Avoid at all costs. Complete waste of time and money.", - "Remarkable quality for the price. Very satisfied!", - "Broke within a week. Clearly poor manufacturing.", - "Exceptional value! Would definitely buy again.", - "Misleading photos. Product looks nothing like advertised.", - "Works like a charm. Installation was easy too.", - "Returned it immediately. Not worth even half the price.", - "Beautiful design and sturdy construction. Love it!", - "Arrived late and damaged. Very frustrating experience.", - "Top-notch quality! Highly recommend to everyone.", - "Uncomfortable and poorly made. Regret buying this.", - "Perfect fit and great finish. Couldn't be happier!", - "Stopped working after two uses. Complete junk." - ]) - - y_train = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]) # 1=positive, 0=negative - + X_train = np.array( + [ + "I love this product! It's amazing and works perfectly.", + "This is terrible. Worst purchase ever made.", + "Great quality and fast shipping. Highly recommend!", + "Poor quality, broke after one day. Very disappointed.", + "Excellent customer service and great value for money.", + "Overpriced and doesn't work as advertised.", + "Perfect! Exactly what I was looking for.", + "Waste of money. Should have read reviews first.", + "Outstanding product with excellent build quality.", + "Cheap plastic, feels like it will break soon.", + "Absolutely fantastic! 
Exceeded all my expectations.", + "Horrible experience. Customer service was rude and unhelpful.", + "Best purchase I've made this year. Five stars!", + "Defective item arrived. Packaging was also damaged.", + "Super impressed with the performance and durability.", + "Total disappointment. Doesn't match the description at all.", + "Wonderful product! My whole family loves it.", + "Avoid at all costs. Complete waste of time and money.", + "Remarkable quality for the price. Very satisfied!", + "Broke within a week. Clearly poor manufacturing.", + "Exceptional value! Would definitely buy again.", + "Misleading photos. Product looks nothing like advertised.", + "Works like a charm. Installation was easy too.", + "Returned it immediately. Not worth even half the price.", + "Beautiful design and sturdy construction. Love it!", + "Arrived late and damaged. Very frustrating experience.", + "Top-notch quality! Highly recommend to everyone.", + "Uncomfortable and poorly made. Regret buying this.", + "Perfect fit and great finish. Couldn't be happier!", + "Stopped working after two uses. Complete junk.", + ] + ) + + y_train = np.array( + [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0] + ) # 1=positive, 0=negative + # Validation data - X_val = np.array([ - "Good product, satisfied with purchase.", - "Not worth the money, poor quality.", - "Really happy with this purchase. Great item!", - "Disappointed with the quality. Expected better.", - "Solid product that does what it promises.", - "Don't waste your money on this. Very poor.", - "Impressive quality and quick delivery.", - "Malfunctioned right out of the box. Terrible." - ]) + X_val = np.array( + [ + "Good product, satisfied with purchase.", + "Not worth the money, poor quality.", + "Really happy with this purchase. Great item!", + "Disappointed with the quality. Expected better.", + "Solid product that does what it promises.", + "Don't waste your money on this. Very poor.", + "Impressive quality and quick delivery.", + "Malfunctioned right out of the box. Terrible.", + ] + ) y_val = np.array([1, 0, 1, 0, 1, 0, 1, 0]) - + # Test data - X_test = np.array([ - "This is an amazing product with great features!", - "Completely disappointed with this purchase.", - "Excellent build quality and works as expected.", - "Not recommended. Had issues from day one.", - "Fantastic product! Worth every penny.", - "Failed to meet basic expectations. Very poor.", - "Love it! Exactly as described and high quality.", - "Cheap materials and sloppy construction. Avoid.", - "Superb performance and easy to use. Highly satisfied!", - "Unreliable and frustrating. Should have bought elsewhere." - ]) + X_test = np.array( + [ + "This is an amazing product with great features!", + "Completely disappointed with this purchase.", + "Excellent build quality and works as expected.", + "Not recommended. Had issues from day one.", + "Fantastic product! Worth every penny.", + "Failed to meet basic expectations. Very poor.", + "Love it! Exactly as described and high quality.", + "Cheap materials and sloppy construction. Avoid.", + "Superb performance and easy to use. Highly satisfied!", + "Unreliable and frustrating. 
Should have bought elsewhere.", + ] + ) y_test = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0]) - + print(f"Training samples: {len(X_train)}") print(f"Validation samples: {len(X_val)}") print(f"Test samples: {len(X_test)}") - + # Create and train tokenizer print("\nšŸ—ļø Creating and training WordPiece tokenizer...") tokenizer = WordPieceTokenizer(vocab_size=5000, output_dim=128) - + # Train tokenizer on the training corpus training_corpus = X_train.tolist() tokenizer.train(training_corpus) @@ -107,17 +110,11 @@ def main(): # Create model configuration print("\nšŸ”§ Creating model configuration...") - model_config = ModelConfig( - embedding_dim=50, - num_classes=2 - ) + model_config = ModelConfig(embedding_dim=50, num_classes=2) # Create classifier print("\nšŸ”Ø Creating classifier...") - classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=model_config - ) + classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config) print("āœ… Classifier created successfully!") print(classifier) # Train the model @@ -130,13 +127,10 @@ def main(): num_workers=0, # Use 0 for simple examples to avoid multiprocessing issues ) classifier.train( - X_train, y_train, - training_config=training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_train, y_train, training_config=training_config, X_val=X_val, y_val=y_val, verbose=True ) print("āœ… Training completed!") - + # Make predictions print("\nšŸ”® Making predictions...") result = classifier.predict(X_test) @@ -149,19 +143,19 @@ def main(): # Calculate accuracy accuracy = (predictions == y_test).mean() print(f"Test accuracy: {accuracy:.3f}") - + # Show detailed results print("\nšŸ“Š Detailed Results:") print("-" * 40) for i, (text, pred, true) in enumerate(zip(X_test, predictions, y_test)): sentiment = "Positive" if pred == 1 else "Negative" correct = "āœ…" if pred == true else "āŒ" - print(f"{i+1}. {correct} Predicted: {sentiment}") + print(f"{i + 1}. 
{correct} Predicted: {sentiment}") print(f" Text: {text[:50]}...") print() - + print("\nšŸŽ‰ Example completed successfully!") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/multiclass_classification.py b/examples/multiclass_classification.py index 92e5d48..e50063f 100644 --- a/examples/multiclass_classification.py +++ b/examples/multiclass_classification.py @@ -7,7 +7,6 @@ """ import os -import random import warnings import numpy as np @@ -17,13 +16,14 @@ from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers from torchTextClassifiers.tokenizers import WordPieceTokenizer + def main(): # Set seed for reproducibility SEED = 42 # Set environment variables for full reproducibility - os.environ['PYTHONHASHSEED'] = str(SEED) - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + os.environ["PYTHONHASHSEED"] = str(SEED) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Use PyTorch Lightning's seed_everything for comprehensive seeding seed_everything(SEED, workers=True) @@ -35,10 +35,7 @@ def main(): # Suppress PyTorch Lightning warnings for cleaner output warnings.filterwarnings( - 'ignore', - message='.*', - category=UserWarning, - module='pytorch_lightning' + "ignore", message=".*", category=UserWarning, module="pytorch_lightning" ) print("šŸŽ­ Multi-class Text Classification Example") @@ -46,54 +43,76 @@ def main(): # Create multi-class sample data (3 classes: 0=negative, 1=neutral, 2=positive) print("šŸ“ Creating multi-class sentiment data...") - X_train = np.array([ - # Negative examples (class 0) - "This product is terrible and I hate it completely.", - "Worst purchase ever. Total waste of money.", - "Absolutely awful quality. Very disappointed.", - "Poor service and terrible product quality.", - "I regret buying this. Complete failure.", - - # Neutral examples (class 1) - "The product is okay, nothing special though.", - "It works but could be better designed.", - "Average quality for the price point.", - "Not bad but not great either.", - "It's fine, meets basic expectations.", - - # Positive examples (class 2) - "Excellent product! Highly recommended!", - "Amazing quality and great customer service.", - "Perfect! Exactly what I was looking for.", - "Outstanding value and excellent performance.", - "Love it! Will definitely buy again." - ]) - - y_train = np.array([0, 0, 0, 0, 0, # negative - 1, 1, 1, 1, 1, # neutral - 2, 2, 2, 2, 2]) # positive - + X_train = np.array( + [ + # Negative examples (class 0) + "This product is terrible and I hate it completely.", + "Worst purchase ever. Total waste of money.", + "Absolutely awful quality. Very disappointed.", + "Poor service and terrible product quality.", + "I regret buying this. Complete failure.", + # Neutral examples (class 1) + "The product is okay, nothing special though.", + "It works but could be better designed.", + "Average quality for the price point.", + "Not bad but not great either.", + "It's fine, meets basic expectations.", + # Positive examples (class 2) + "Excellent product! Highly recommended!", + "Amazing quality and great customer service.", + "Perfect! Exactly what I was looking for.", + "Outstanding value and excellent performance.", + "Love it! 
Will definitely buy again.", + ] + ) + + y_train = np.array( + [ + 0, + 0, + 0, + 0, + 0, # negative + 1, + 1, + 1, + 1, + 1, # neutral + 2, + 2, + 2, + 2, + 2, + ] + ) # positive + # Validation data - X_val = np.array([ - "Bad quality, not recommended.", # negative - "It's okay, does the job.", # neutral - "Great product, very satisfied!" # positive - ]) + X_val = np.array( + [ + "Bad quality, not recommended.", # negative + "It's okay, does the job.", # neutral + "Great product, very satisfied!", # positive + ] + ) y_val = np.array([0, 1, 2]) - + # Test data - X_test = np.array([ - "This is absolutely horrible!", - "It's an average product, nothing more.", - "Fantastic! Love every aspect of it!", - "Really poor design and quality.", - "Works well, good value for money.", - "Outstanding product with amazing features!" - ]) + X_test = np.array( + [ + "This is absolutely horrible!", + "It's an average product, nothing more.", + "Fantastic! Love every aspect of it!", + "Really poor design and quality.", + "Works well, good value for money.", + "Outstanding product with amazing features!", + ] + ) y_test = np.array([0, 1, 2, 0, 1, 2]) - + print(f"Training samples: {len(X_train)}") - print(f"Class distribution: Negative={sum(y_train==0)}, Neutral={sum(y_train==1)}, Positive={sum(y_train==2)}") + print( + f"Class distribution: Negative={sum(y_train == 0)}, Neutral={sum(y_train == 1)}, Positive={sum(y_train == 2)}" + ) # Create and train tokenizer print("\nšŸ—ļø Creating and training WordPiece tokenizer...") @@ -106,15 +125,12 @@ def main(): print("\nšŸ”§ Creating model configuration...") model_config = ModelConfig( embedding_dim=64, - num_classes=3 # 3 classes for sentiment (negative, neutral, positive) + num_classes=3, # 3 classes for sentiment (negative, neutral, positive) ) # Create classifier print("\nšŸ”Ø Creating multi-class classifier...") - classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=model_config - ) + classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config) print("āœ… Classifier created successfully!") # Train the model @@ -125,16 +141,13 @@ def main(): lr=1e-3, patience_early_stopping=7, num_workers=0, - trainer_params={'deterministic': True} + trainer_params={"deterministic": True}, ) classifier.train( - X_train, y_train, - training_config=training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_train, y_train, training_config=training_config, X_val=X_val, y_val=y_val, verbose=True ) print("āœ… Training completed!") - + # Make predictions print("\nšŸ”® Making predictions...") result = classifier.predict(X_test) @@ -145,10 +158,10 @@ def main(): # Calculate accuracy accuracy = (predictions == y_test).mean() print(f"Test accuracy: {accuracy:.3f}") - + # Define class names for better output class_names = ["Negative", "Neutral", "Positive"] - + # Show detailed results print("\nšŸ“Š Detailed Results:") print("-" * 60) @@ -160,15 +173,17 @@ def main(): if correct: correct_predictions += 1 status = "āœ…" if correct else "āŒ" - - print(f"{i+1}. {status} Predicted: {predicted_sentiment}, True: {true_sentiment}") + + print(f"{i + 1}. 
{status} Predicted: {predicted_sentiment}, True: {true_sentiment}") print(f" Text: {text}") print() - - print(f"Final Accuracy: {correct_predictions}/{len(X_test)} = {correct_predictions/len(X_test):.3f}") - - + + print( + f"Final Accuracy: {correct_predictions}/{len(X_test)} = {correct_predictions / len(X_test):.3f}" + ) + print("\nšŸŽ‰ Multi-class example completed successfully!") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/multilevel_example.py b/examples/multilevel_example.py new file mode 100644 index 0000000..f6b347f --- /dev/null +++ b/examples/multilevel_example.py @@ -0,0 +1,219 @@ +from typing import cast + +import numpy as np +import pandas as pd +import torch +from sklearn.preprocessing import LabelEncoder + +from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers +from torchTextClassifiers.contrib import ( + MultiLevelCrossEntropyLoss, + MultiLevelTextClassificationModel, +) +from torchTextClassifiers.dataset import TextClassificationDataset +from torchTextClassifiers.model import TextClassificationModule +from torchTextClassifiers.model.components import ( + AttentionConfig, + CategoricalForwardType, + CategoricalVariableNet, + ClassificationHead, + LabelAttentionConfig, + SentenceEmbedder, + SentenceEmbedderConfig, + TokenEmbedder, + TokenEmbedderConfig, +) +from torchTextClassifiers.tokenizers import WordPieceTokenizer +from torchTextClassifiers.value_encoder import DictEncoder, ValueEncoder + +sample_text_data = [ + "This is a positive example", + "This is a negative example", + "Another positive case", + "Another negative case", + "Good example here", + "Bad example here", +] + +categorical_data = np.array( + [ + ["cat", "red"], + ["dog", "blue"], + ["cat", "red"], + ["dog", "blue"], + ["cat", "red"], + ["dog", "blue"], + ] +) + +labels_level_1 = np.array(["positive", "negative", "positive", "negative", "positive", "neutral"]) +labels_level_2 = np.array(["good", "bad", "good", "bad", "good", "bad"]) +labels_level3 = np.array(["A", "B", "D", "B", "C", "B"]) + + +df = pd.DataFrame( + { + "text": sample_text_data, + "category": categorical_data[:, 0], + "color": categorical_data[:, 1], + "label_level_1": labels_level_1, # You can switch to labels_level_2 or labels_level_3 for testing + "label_level_2": labels_level_2, + "label_level_3": labels_level3, + } +) +vocab_size = 10 +tokenizer = WordPieceTokenizer(vocab_size, output_dim=50) +tokenizer.train(sample_text_data) + +encoders = {} +# category : DictEncoder (ours) +feature = "category" +mapping = {val: idx for idx, val in enumerate(df[feature].unique())} +encoders[feature] = DictEncoder(mapping) + +# color: LabelEncoder (sklearn) +le = LabelEncoder() +le.fit(df["color"]) +encoders["color"] = le + +feature = "label_level_1" +le_label = LabelEncoder() +le_label.fit(df[feature]) + +feature = "label_level_2" +le_label_2 = LabelEncoder() +le_label_2.fit(df[feature]) + +feature = "label_level_3" +le_label_3 = DictEncoder({val: idx for idx, val in enumerate(df[feature].unique())}) + +label_encoder = [le_label, le_label_2, le_label_3] +# OR you can also use DictEncoder +# dict_mapping = {val: idx for idx, val in enumerate(df[feature].unique())} +# label_encoder = DictEncoder(dict_mapping) + +value_encoder = ValueEncoder(label_encoder, encoders) + + +model_config = ModelConfig( + embedding_dim=10, + categorical_embedding_dims=5, + n_heads_label_attention=2, + num_classes=value_encoder.num_classes, + attention_config=AttentionConfig(n_layers=2, n_head=5, 
n_kv_head=5, positional_encoding=False), + aggregation_method=None, +) +training_config = TrainingConfig( + num_epochs=1, + batch_size=6, + lr=1e-3, + raw_categorical_inputs=True, +) + +train_dataset = TextClassificationDataset( + texts=df["text"].values, + categorical_variables=value_encoder.transform( + df[["category", "color"]].values + ), # None if no cat vars + tokenizer=tokenizer, + labels=value_encoder.transform_labels( + df[["label_level_1", "label_level_2", "label_level_3"]].values + ), # None if no labels +) +train_dataloader = train_dataset.create_dataloader( + batch_size=training_config.batch_size, + num_workers=training_config.num_workers, + shuffle=False, + **training_config.dataloader_params if training_config.dataloader_params else {}, +) +batch = next(iter(train_dataloader)) + + +token_embedder_config = TokenEmbedderConfig( + vocab_size=tokenizer.vocab_size, + embedding_dim=model_config.embedding_dim, + padding_idx=tokenizer.padding_idx, + attention_config=model_config.attention_config, +) +token_embedder = TokenEmbedder( + token_embedder_config=token_embedder_config, +) +categorical_var_net = CategoricalVariableNet( + categorical_vocabulary_sizes=value_encoder.vocabulary_sizes, + categorical_embedding_dims=model_config.categorical_embedding_dims, + text_embedding_dim=model_config.embedding_dim, +) + +all_sentence_embedders = [] +all_classification_heads = [] + +for num_classes in value_encoder.num_classes: # ty:ignore[not-iterable] + sentence_embedder_config = SentenceEmbedderConfig( + label_attention_config=LabelAttentionConfig( + n_head=model_config.n_heads_label_attention, + num_classes=num_classes, + embedding_dim=model_config.embedding_dim, + ), + aggregation_method=model_config.aggregation_method, + ) + + sentence_embedder = SentenceEmbedder(sentence_embedder_config=sentence_embedder_config) + all_sentence_embedders.append(sentence_embedder) + + classif_head_input_dim = model_config.embedding_dim + if categorical_var_net.forward_type != CategoricalForwardType.SUM_TO_TEXT: + classif_head_input_dim += categorical_var_net.output_dim + + # because we use LabelAttention, the sentence embedder outputs a (num_classes, embedding_dim) tensor, and the classification head should output a single logit per class (i.e. 
num_classes=1) + classification_head = ClassificationHead(input_dim=classif_head_input_dim, num_classes=1) + all_classification_heads.append(classification_head) + + +model = MultiLevelTextClassificationModel( + token_embedder=token_embedder, + sentence_embedders=all_sentence_embedders, + classification_heads=all_classification_heads, + categorical_variable_net=categorical_var_net, +) + +module = TextClassificationModule( + model=model, + loss=MultiLevelCrossEntropyLoss(), + optimizer=torch.optim.Adam, + optimizer_params={"lr": 1e-3}, + scheduler=None, + scheduler_params=None, +) + +print(model.num_classes) + +batch = next(iter(train_dataloader)) +print(batch["labels"].shape) +outputs = model(**batch) +print(f"Outputs shapes: {[output.shape for output in outputs]}") + + +ttc = torchTextClassifiers.from_model( + tokenizer=tokenizer, pytorch_model=model, value_encoder=value_encoder +) + +training_config = TrainingConfig( + num_epochs=1, + batch_size=6, + lr=1e-3, + raw_categorical_inputs=True, + loss=MultiLevelCrossEntropyLoss(num_classes=cast(list[int], value_encoder.num_classes)), +) + +ttc.train( + X_train=df[["text", "category", "color"]].values, + y_train=df[["label_level_1", "label_level_2", "label_level_3"]].values, + training_config=training_config, +) + + +print( + ttc.predict( + X_test=df[["text", "category", "color"]].values, + ) +) diff --git a/examples/simple_explainability_example.py b/examples/simple_explainability_example.py index 20cc1d5..1febde1 100644 --- a/examples/simple_explainability_example.py +++ b/examples/simple_explainability_example.py @@ -14,7 +14,6 @@ from torchTextClassifiers.tokenizers import WordPieceTokenizer from torchTextClassifiers.utilities.plot_explainability import ( map_attributions_to_char, - map_attributions_to_word, ) @@ -23,8 +22,8 @@ def main(): SEED = 42 # Set environment variables for full reproducibility - os.environ['PYTHONHASHSEED'] = str(SEED) - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + os.environ["PYTHONHASHSEED"] = str(SEED) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Use PyTorch Lightning's seed_everything for comprehensive seeding seed_everything(SEED, workers=True) @@ -36,64 +35,94 @@ def main(): # Suppress PyTorch Lightning warnings for cleaner output warnings.filterwarnings( - 'ignore', - message='.*', - category=UserWarning, - module='pytorch_lightning' + "ignore", message=".*", category=UserWarning, module="pytorch_lightning" ) print("šŸ” Simple Explainability Example") # Enhanced training data with more diverse examples - X_train = np.array([ - # Positive examples - "I love this product", - "Great quality and excellent service", - "Amazing design and fantastic performance", - "Outstanding value for money", - "Excellent customer support team", - "Love the innovative features", - "Perfect solution for my needs", - "Highly recommend this item", - "Superb build quality", - "Wonderful experience overall", - "Great value and fast delivery", - "Excellent product with amazing results", - "Love this fantastic design", - "Perfect quality and great price", - "Amazing customer service experience", - - # Negative examples - "This is terrible quality", - "Poor design and cheap materials", - "Awful experience with this product", - "Terrible customer service response", - "Completely disappointing purchase", - "Poor quality and overpriced item", - "Awful build quality issues", - "Terrible value for money", - "Disappointing performance results", - "Poor service and bad experience", - "Awful design and cheap feel", - "Terrible 
product with many issues", - "Disappointing quality and poor value", - "Bad experience with customer support", - "Poor construction and awful materials" - ]) - - y_train = np.array([ - # Positive labels (1) - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - # Negative labels (0) - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - ]) - - X_val = np.array([ - "Good product with decent quality", - "Bad quality and poor service", - "Excellent value and great design", - "Terrible experience and awful quality" - ]) + X_train = np.array( + [ + # Positive examples + "I love this product", + "Great quality and excellent service", + "Amazing design and fantastic performance", + "Outstanding value for money", + "Excellent customer support team", + "Love the innovative features", + "Perfect solution for my needs", + "Highly recommend this item", + "Superb build quality", + "Wonderful experience overall", + "Great value and fast delivery", + "Excellent product with amazing results", + "Love this fantastic design", + "Perfect quality and great price", + "Amazing customer service experience", + # Negative examples + "This is terrible quality", + "Poor design and cheap materials", + "Awful experience with this product", + "Terrible customer service response", + "Completely disappointing purchase", + "Poor quality and overpriced item", + "Awful build quality issues", + "Terrible value for money", + "Disappointing performance results", + "Poor service and bad experience", + "Awful design and cheap feel", + "Terrible product with many issues", + "Disappointing quality and poor value", + "Bad experience with customer support", + "Poor construction and awful materials", + ] + ) + + y_train = np.array( + [ + # Positive labels (1) + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + # Negative labels (0) + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ] + ) + + X_val = np.array( + [ + "Good product with decent quality", + "Bad quality and poor service", + "Excellent value and great design", + "Terrible experience and awful quality", + ] + ) y_val = np.array([1, 0, 1, 0]) # Create and train tokenizer @@ -105,17 +134,11 @@ def main(): # Create model configuration print("\nšŸ”§ Creating model configuration...") - model_config = ModelConfig( - embedding_dim=50, - num_classes=2 - ) + model_config = ModelConfig(embedding_dim=50, num_classes=2) # Create classifier print("\nšŸ”Ø Creating classifier...") - classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=model_config - ) + classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config) print("āœ… Classifier created successfully!") # Train the model @@ -126,28 +149,25 @@ def main(): lr=1e-3, patience_early_stopping=5, num_workers=0, - trainer_params={'deterministic': True} + trainer_params={"deterministic": True}, ) classifier.train( - X_train, y_train, - training_config=training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_train, y_train, training_config=training_config, X_val=X_val, y_val=y_val, verbose=True ) print("āœ… Training completed!") - + # Test examples with different sentiments test_texts = [ "This product is amazing!", "Poor quality and terrible service", "Great value for money", "Completely disappointing and awful experience", - "Love this excellent design" + "Love this excellent design", ] - + print(f"\nšŸ” Testing explainability on {len(test_texts)} examples:") print("=" * 60) - + for i, test_text in enumerate(test_texts, 1): print(f"\nšŸ“ Example {i}:") 
print(f"Text: '{test_text}'") @@ -159,7 +179,9 @@ def main(): # Extract prediction prediction = result["prediction"][0][0].item() confidence = result["confidence"][0][0].item() - print(f"Prediction: {'Positive' if prediction == 1 else 'Negative'} (confidence: {confidence:.4f})") + print( + f"Prediction: {'Positive' if prediction == 1 else 'Negative'} (confidence: {confidence:.4f})" + ) # Extract attributions and mapping info attributions = result["attributions"][0][0] # shape: (seq_len,) @@ -170,7 +192,7 @@ def main(): char_attributions = map_attributions_to_char( attributions.unsqueeze(0), # Add batch dimension: (1, seq_len) offset_mapping, - test_text + test_text, )[0] # Get first result print("\nšŸ“Š Character-Level Contribution Visualization:") @@ -187,7 +209,7 @@ def main(): for word in words: word_len = len(word) # Get attributions for this word - word_attrs = char_attributions[char_idx:char_idx + word_len] + word_attrs = char_attributions[char_idx : char_idx + word_len] if len(word_attrs) > 0: avg_attr = sum(word_attrs) / len(word_attrs) bar_length = int((avg_attr / max_attr) * bar_width) if max_attr > 0 else 0 @@ -202,7 +224,7 @@ def main(): word_scores = [] for word in words: word_len = len(word) - word_attrs = char_attributions[char_idx:char_idx + word_len] + word_attrs = char_attributions[char_idx : char_idx + word_len] if len(word_attrs) > 0: word_scores.append((word, sum(word_attrs) / len(word_attrs))) char_idx += word_len + 1 @@ -214,33 +236,34 @@ def main(): except Exception as e: print(f"āš ļø Explainability failed: {e}") import traceback + traceback.print_exc() - + # Analysis completed for this example print(f"āœ… Analysis completed for example {i}") - + print(f"\nšŸŽ‰ Explainability analysis completed for {len(test_texts)} examples!") - + # Interactive section for user input (only if --interactive flag is provided) if "--interactive" in sys.argv: - print("\n" + "="*60) + print("\n" + "=" * 60) print("šŸŽÆ Interactive Explainability Mode") - print("="*60) + print("=" * 60) print("Enter your own text to see predictions and explanations!") print("Type 'quit' or 'exit' to end the session.\n") - + while True: try: user_text = input("šŸ’¬ Enter text: ").strip() - - if user_text.lower() in ['quit', 'exit', 'q']: + + if user_text.lower() in ["quit", "exit", "q"]: print("šŸ‘‹ Thanks for using the explainability tool!") break - + if not user_text: print("āš ļø Please enter some text.") continue - + print(f"\nšŸ” Analyzing: '{user_text}'") # Get prediction with explainability @@ -262,7 +285,7 @@ def main(): char_attributions = map_attributions_to_char( attributions.unsqueeze(0), # Add batch dimension: (1, seq_len) offset_mapping, - user_text + user_text, )[0] # Get first result print("\nšŸ“Š Character-Level Contribution Visualization:") @@ -279,10 +302,12 @@ def main(): for word in words: word_len = len(word) # Get attributions for this word - word_attrs = char_attributions[char_idx:char_idx + word_len] + word_attrs = char_attributions[char_idx : char_idx + word_len] if len(word_attrs) > 0: avg_attr = sum(word_attrs) / len(word_attrs) - bar_length = int((avg_attr / max_attr) * bar_width) if max_attr > 0 else 0 + bar_length = ( + int((avg_attr / max_attr) * bar_width) if max_attr > 0 else 0 + ) bar = "ā–ˆ" * bar_length print(f"{word:>15} | {bar:<40} {avg_attr:.4f}") char_idx += word_len + 1 # +1 for space @@ -294,23 +319,26 @@ def main(): word_scores = [] for word in words: word_len = len(word) - word_attrs = char_attributions[char_idx:char_idx + word_len] + word_attrs = 
char_attributions[char_idx : char_idx + word_len] if len(word_attrs) > 0: word_scores.append((word, sum(word_attrs) / len(word_attrs))) char_idx += word_len + 1 if word_scores: top_word, top_score = max(word_scores, key=lambda x: x[1]) - print(f"šŸ’” Most influential word: '{top_word}' (avg score: {top_score:.4f})") + print( + f"šŸ’” Most influential word: '{top_word}' (avg score: {top_score:.4f})" + ) except Exception as e: print(f"āš ļø Explainability failed: {e}") print("šŸ” Prediction available, but detailed explanation unavailable.") import traceback + traceback.print_exc() - - print("\n" + "-"*50) - + + print("\n" + "-" * 50) + except KeyboardInterrupt: print("\nšŸ‘‹ Session interrupted. Goodbye!") break @@ -318,9 +346,11 @@ def main(): print(f"āš ļø Error: {e}") continue else: - print("\nšŸ’” Tip: Use --interactive flag to enter interactive mode for custom text analysis!") + print( + "\nšŸ’” Tip: Use --interactive flag to enter interactive mode for custom text analysis!" + ) print(" Example: uv run python examples/simple_explainability_example.py --interactive") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/using_additional_features.py b/examples/using_additional_features.py index 90ab9b2..c740510 100644 --- a/examples/using_additional_features.py +++ b/examples/using_additional_features.py @@ -7,7 +7,6 @@ """ import os -import random import time import warnings @@ -20,9 +19,11 @@ from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers from torchTextClassifiers.tokenizers import WordPieceTokenizer + # Note: SimpleTextWrapper is not available in the current version # from torchTextClassifiers.classifiers.simple_text_classifier import SimpleTextConfig, SimpleTextWrapper + def stratified_split_rare_labels(X, y, test_size=0.2, min_train_samples=1): # Get unique labels and their frequencies unique_labels, label_counts = np.unique(y, return_counts=True) @@ -54,10 +55,10 @@ def stratified_split_rare_labels(X, y, test_size=0.2, min_train_samples=1): def merge_cat(cat): - if cat in ['World', 'Top News', 'Europe', 'Italia', 'U.S.', 'Top Stories']: - return 'World News' - if cat in ['Sci/Tech', 'Software and Developement', 'Toons', 'Health', 'Music Feeds']: - return 'Tech and Stuff' + if cat in ["World", "Top News", "Europe", "Italia", "U.S.", "Top Stories"]: + return "World News" + if cat in ["Sci/Tech", "Software and Developement", "Toons", "Health", "Music Feeds"]: + return "Tech and Stuff" return cat @@ -67,10 +68,9 @@ def load_and_prepare_data(): df = pd.read_parquet("https://minio.lab.sspcloud.fr/h4njlg/public/ag_news_full_1M.parquet") df = df.sample(10000, random_state=42) # Smaller sample to avoid disk space issues print(f"āœ… Loaded {len(df)} samples from AG NEWS dataset") - - df['category_final'] = df['category'].apply(lambda x: merge_cat(x)) - df['title_headline'] = df['title'] + '\n####\n' + df['description'] + df["category_final"] = df["category"].apply(lambda x: merge_cat(x)) + df["title_headline"] = df["title"] + "\n####\n" + df["description"] # categorical_features = None # text_feature = "title_headline" @@ -79,29 +79,32 @@ def load_and_prepare_data(): source_encoder = LabelEncoder() df["title_headline_processed"] = df["title_headline"] - df["source_encoded"] = source_encoder.fit_transform(df['source']) + df["source_encoded"] = source_encoder.fit_transform(df["source"]) - X_text_only = df[['title_headline_processed']].values - X_mixed = df[['title_headline_processed', "source_encoded"]].values - y = 
df['category_final'].values + X_text_only = df[["title_headline_processed"]].values + X_mixed = df[["title_headline_processed", "source_encoded"]].values + y = df["category_final"].values encoder = LabelEncoder() y = encoder.fit_transform(y) return X_text_only, X_mixed, y, encoder - + def train_and_evaluate_model(X, y, model_name, use_categorical=False, use_simple=False): - """Train and evaluate a FastText model""" + """Train and evaluate a text classifier""" print(f"\nšŸŽÆ Training {model_name}...") - - + # Split data twice: first for train/temp, then temp into validation/test X_train, X_temp, y_train, y_temp = stratified_split_rare_labels( - X, y, test_size=0.1 # 40% for validation + test + X, + y, + test_size=0.1, # 10% held out for validation + test ) X_val, X_test, y_val, y_test = stratified_split_rare_labels( - X_temp, y_temp, test_size=0.5 # Split temp 50/50 into validation and test + X_temp, + y_temp, + test_size=0.5, # Split temp 50/50 into validation and test ) - + # Note: SimpleTextWrapper is not available in the current version # The use_simple branch has been disabled if use_simple: @@ -133,22 +136,16 @@ def train_and_evaluate_model(X, y, model_name, use_categorical=False, use_simple embedding_dim=50, categorical_vocabulary_sizes=vocab_sizes, categorical_embedding_dims=10, - num_classes=5 + num_classes=5, ) print(f" Categorical vocabulary sizes: {vocab_sizes}") else: # For text-only model - model_config = ModelConfig( - embedding_dim=50, - num_classes=5 - ) + model_config = ModelConfig(embedding_dim=50, num_classes=5) # Create classifier print(" šŸ”Ø Creating classifier...") - classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=model_config - ) + classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config) print(" āœ… Classifier created successfully!") # Training configuration @@ -158,10 +155,7 @@ def train_and_evaluate_model(X, y, model_name, use_categorical=False, use_simple lr=0.001, patience_early_stopping=3, num_workers=0, - trainer_params={ - 'enable_progress_bar': True, - 'deterministic': True - } + trainer_params={"enable_progress_bar": True, "deterministic": True}, ) # Create and build model @@ -170,10 +164,7 @@ def train_and_evaluate_model(X, y, model_name, use_categorical=False, use_simple # Train model print(" šŸŽÆ Training model...") classifier.train( - X_train, y_train, - training_config=training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_train, y_train, training_config=training_config, X_val=X_val, y_val=y_val, verbose=True ) training_time = time.time() - start_time @@ -201,26 +192,24 @@ def train_and_evaluate_model(X, y, model_name, use_categorical=False, use_simple print(f" āš ļø Validation failed: {e}") test_accuracy = 0.0 predictions = np.zeros(len(y_test)) - + return { - 'model_name': model_name, - 'test_accuracy': test_accuracy, - 'training_time': training_time, - 'predictions': predictions, - 'y_test': y_test, - 'classifier': classifier + "model_name": model_name, + "test_accuracy": test_accuracy, + "training_time": training_time, + "predictions": predictions, + "y_test": y_test, + "classifier": classifier, } - - def main(): # Set seed for reproducibility SEED = 42 # Set environment variables for full reproducibility - os.environ['PYTHONHASHSEED'] = str(SEED) - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + os.environ["PYTHONHASHSEED"] = str(SEED) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Use PyTorch Lightning's seed_everything for comprehensive seeding seed_everything(SEED, workers=True) @@ -232,10 +221,7 @@ def main(): # Suppress PyTorch 
Lightning warnings for cleaner output warnings.filterwarnings( - 'ignore', - message='.*', - category=UserWarning, - module='pytorch_lightning' + "ignore", message=".*", category=UserWarning, module="pytorch_lightning" ) print("šŸ”€ Classifier: Categorical Features Comparison") @@ -245,11 +231,11 @@ def main(): # Load and prepare data (same as notebook) X_text_only, X_mixed, y, encoder = load_and_prepare_data() - + # Train models - print(f"\nšŸš€ Training Models:") + print("\nšŸš€ Training Models:") print("-" * 40) - + # Text-only model results_text_only = train_and_evaluate_model( X_text_only, y, "Text-Only Classifier", use_categorical=False @@ -264,21 +250,26 @@ def main(): # results_tfidf = train_and_evaluate_model(X_text_only, y, "TF-IDF classifier", use_categorical=False, use_simple=True) # Compare results - print(f"\nšŸ“Š Results Comparison:") + print("\nšŸ“Š Results Comparison:") print("=" * 50) print(f"{'Model':<25}{'Test Acc':<11} {'Time (s)':<10}") print("-" * 50) - print(f"{'Text-Only':<25} " - f"{results_text_only['test_accuracy']:<11.3f} {results_text_only['training_time']:<10.1f}") - print(f"{'Mixed Features':<25} " - f"{results_mixed['test_accuracy']:<11.3f} {results_mixed['training_time']:<10.1f}") + print( + f"{'Text-Only':<25} " + f"{results_text_only['test_accuracy']:<11.3f} {results_text_only['training_time']:<10.1f}" + ) + print( + f"{'Mixed Features':<25} " + f"{results_mixed['test_accuracy']:<11.3f} {results_mixed['training_time']:<10.1f}" + ) # Calculate improvements - acc_improvement = results_mixed['test_accuracy'] - results_text_only['test_accuracy'] - time_overhead = results_mixed['training_time'] - results_text_only['training_time'] - + acc_improvement = results_mixed["test_accuracy"] - results_text_only["test_accuracy"] + time_overhead = results_mixed["training_time"] - results_text_only["training_time"] + print("-" * 50) print(f"Test Accuracy Improvement: {acc_improvement:+.3f}") print(f"Training Time Overhead: {time_overhead:+.1f}s") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/torchTextClassifiers/contrib/__init__.py b/torchTextClassifiers/contrib/__init__.py new file mode 100644 index 0000000..1f4bef9 --- /dev/null +++ b/torchTextClassifiers/contrib/__init__.py @@ -0,0 +1,13 @@ +"""contrib: example custom architectures for torchTextClassifiers. + +These classes are reference implementations that demonstrate how to build custom +PyTorch models compatible with the ``torchTextClassifiers.from_model`` entry point. +They are not part of the core API and may evolve independently. +""" + +from .multilevel import MultiLevelCrossEntropyLoss, MultiLevelTextClassificationModel + +__all__ = [ + "MultiLevelTextClassificationModel", + "MultiLevelCrossEntropyLoss", +] diff --git a/torchTextClassifiers/contrib/multilevel.py b/torchTextClassifiers/contrib/multilevel.py new file mode 100644 index 0000000..7a39755 --- /dev/null +++ b/torchTextClassifiers/contrib/multilevel.py @@ -0,0 +1,191 @@ +"""Multi-level (multi-task) classification architecture. + +Provides a custom model and a matching loss function for tasks that require +predicting several classification targets simultaneously from the same text +input — for example, hierarchical category codes (level 1, level 2, level 3) +predicted in a single forward pass. + +These classes are compatible with ``torchTextClassifiers.from_model`` and can +serve as a starting point for your own multi-task architectures. 
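+
+The matching ``MultiLevelCrossEntropyLoss`` expects the integer labels for all
+levels stacked in one tensor of shape ``(batch, n_levels)``, where column ``i``
+holds the ground-truth label for level ``i``. A minimal sketch, assuming three
+levels with per-level label tensors ``y_level1`` ... ``y_level3``::
+
+    import torch
+
+    # column i holds the ground-truth label for level i
+    labels = torch.stack([y_level1, y_level2, y_level3], dim=1)  # (batch, 3)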
+ +Example usage:: + + from torchTextClassifiers import torchTextClassifiers + from torchTextClassifiers.contrib import ( + MultiLevelTextClassificationModel, + MultiLevelCrossEntropyLoss, + ) + from torchTextClassifiers.model import TextClassificationModule + + model = MultiLevelTextClassificationModel( + token_embedder=token_embedder, + sentence_embedders=sentence_embedders, # one per level + classification_heads=classification_heads, # one per level + categorical_variable_net=categorical_var_net, + ) + + # Train with PyTorch Lightning directly + module = TextClassificationModule( + model=model, + loss=MultiLevelCrossEntropyLoss(), + optimizer=torch.optim.Adam, + optimizer_params={"lr": 1e-3}, + ) + + # Or wrap with the high-level API for predict() / save() / load() + classifier = torchTextClassifiers.from_model( + tokenizer=tokenizer, + pytorch_model=model, + ) +""" + +from typing import Optional + +import torch +from torch import nn + +from torchTextClassifiers.model.components import ( + CategoricalForwardType, + CategoricalVariableNet, + ClassificationHead, + SentenceEmbedder, + TokenEmbedder, +) + + +class MultiLevelTextClassificationModel(nn.Module): + """Multi-task text classifier that predicts several classes in one forward pass. + + Each classification level has its own ``SentenceEmbedder`` and + ``ClassificationHead`` but they all share the same ``TokenEmbedder``, so + the token-level representations are computed only once. + + Attributes: + num_classes: List of class counts, one entry per level. Required by + ``torchTextClassifiers.from_model``. + categorical_variable_net: The categorical embedding module (may be + ``None``). Required by ``torchTextClassifiers.from_model``. + + Args: + token_embedder: Shared token embedding module. + sentence_embedders: One ``SentenceEmbedder`` per classification level. + classification_heads: One ``ClassificationHead`` per level. + categorical_variable_net: Categorical feature embedding module. + """ + + def __init__( + self, + token_embedder: TokenEmbedder, + sentence_embedders: list[SentenceEmbedder], + classification_heads: list[ClassificationHead], + categorical_variable_net: CategoricalVariableNet, + ): + super().__init__() + self.token_embedder = token_embedder + self.sentence_embedders = nn.ModuleList(sentence_embedders) + self.classification_heads = nn.ModuleList(classification_heads) + self.categorical_variable_net = categorical_variable_net + self.num_classes: list[int] = [ + se.label_attention_config.num_classes + if se.label_attention_config is not None + else ch.num_classes + for se, ch in zip(sentence_embedders, classification_heads) + ] + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + categorical_vars: Optional[torch.Tensor] = None, + **kwargs, + ) -> list[torch.Tensor]: + """Run a forward pass and return one logit tensor per level. + + Args: + input_ids: Tokenised text, shape ``(batch, seq_len)``. + attention_mask: Padding mask, shape ``(batch, seq_len)``. + categorical_vars: Integer-encoded categorical features, + shape ``(batch, n_cats)``. May be ``None`` if no categorical + features are used. + + Returns: + List of raw logit tensors, one per classification level. + Each tensor has shape ``(batch, num_classes_at_level)``. 
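+
+        Example (illustrative; assumes a built ``model`` and a tokenised batch)::
+
+            logits_per_level = model(input_ids, attention_mask, categorical_vars)
+            level1_pred = logits_per_level[0].argmax(dim=-1)  # (batch,)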
+ """ + token_embed_output = self.token_embedder(input_ids, attention_mask) + x_token = token_embed_output["token_embeddings"] + x_cat = self.categorical_variable_net(categorical_vars) + + outputs = [] + for sentence_embedder, classification_head in zip( + self.sentence_embedders, self.classification_heads + ): + if sentence_embedder.label_attention_config is not None: + num_cls = sentence_embedder.label_attention_config.num_classes + x_cat_level = x_cat.unsqueeze(1).expand(-1, num_cls, -1) + else: + x_cat_level = x_cat + + sentence_embedding = sentence_embedder( + token_embeddings=x_token, attention_mask=attention_mask + )["sentence_embedding"] + + fwd = self.categorical_variable_net.forward_type + if fwd in ( + CategoricalForwardType.AVERAGE_AND_CONCAT, + CategoricalForwardType.CONCATENATE_ALL, + ): + x_combined = torch.cat((sentence_embedding, x_cat_level), dim=-1) + else: + assert fwd == CategoricalForwardType.SUM_TO_TEXT + x_combined = sentence_embedding + x_cat_level + + outputs.append(classification_head(x_combined).squeeze(-1)) + + return outputs + + +class MultiLevelCrossEntropyLoss(nn.Module): + """Weighted cross-entropy loss across multiple classification levels. + + Averages the per-level cross-entropy losses, optionally weighting each + level by its number of classes so that finer-grained levels contribute + more to the total gradient. + + Args: + num_classes: If provided, level ``i`` is weighted by + ``num_classes[i] / sum(num_classes)``. If ``None``, all levels + are weighted equally. + + Example:: + + loss_fn = MultiLevelCrossEntropyLoss(num_classes=[5, 20, 100]) + # or unweighted: + loss_fn = MultiLevelCrossEntropyLoss() + """ + + def __init__(self, num_classes: Optional[list[int]] = None): + super().__init__() + self.num_classes = num_classes + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, outputs: list[torch.Tensor], labels: torch.Tensor) -> torch.Tensor: + """Compute the weighted average loss. + + Args: + outputs: List of logit tensors ``(batch, num_classes_i)`` returned + by ``MultiLevelTextClassificationModel``. + labels: Integer label tensor of shape ``(batch, n_levels)``. + Column ``i`` contains the ground-truth label for level ``i``. + + Returns: + Scalar loss tensor. + """ + total_loss = torch.tensor(0.0, device=outputs[0].device) + for idx, output in enumerate(outputs): + label = labels[:, idx] + weight = self.num_classes[idx] if self.num_classes is not None else 1 + total_loss = total_loss + self.loss_fn(output.squeeze(), label) * weight + + total_weight = sum(self.num_classes) if self.num_classes is not None else len(outputs) + return total_loss / total_weight diff --git a/torchTextClassifiers/dataset/dataset.py b/torchTextClassifiers/dataset/dataset.py index 6475567..064d78b 100644 --- a/torchTextClassifiers/dataset/dataset.py +++ b/torchTextClassifiers/dataset/dataset.py @@ -52,6 +52,8 @@ def __init__( logger.warning( "ragged_multilabel set to True but max label value is 1. If your labels are already one-hot encoded, set ragged_multilabel to False. Otherwise computations are likely to be wrong." 
) + elif not self.ragged_multilabel and self.labels is not None: + self.labels = torch.tensor(labels, dtype=torch.long) def __len__(self): return len(self.texts) @@ -100,7 +102,7 @@ def collate_fn(self, batch): labels_tensor[rows, cols] = 1 else: - labels_tensor = torch.tensor(labels) + labels_tensor = torch.stack(list(labels)) else: labels_tensor = None diff --git a/torchTextClassifiers/model/components/__init__.py b/torchTextClassifiers/model/components/__init__.py index 3db1a73..07944f0 100644 --- a/torchTextClassifiers/model/components/__init__.py +++ b/torchTextClassifiers/model/components/__init__.py @@ -9,5 +9,15 @@ ) from .classification_head import ClassificationHead as ClassificationHead from .text_embedder import LabelAttentionConfig as LabelAttentionConfig -from .text_embedder import TokenEmbedder as TokenEmbedder, TokenEmbedderConfig as TokenEmbedderConfig -from .text_embedder import SentenceEmbedder as SentenceEmbedder, SentenceEmbedderConfig as SentenceEmbedderConfig \ No newline at end of file +from .text_embedder import ( + SentenceEmbedder as SentenceEmbedder, +) +from .text_embedder import ( + SentenceEmbedderConfig as SentenceEmbedderConfig, +) +from .text_embedder import ( + TokenEmbedder as TokenEmbedder, +) +from .text_embedder import ( + TokenEmbedderConfig as TokenEmbedderConfig, +) diff --git a/torchTextClassifiers/model/components/classification_head.py b/torchTextClassifiers/model/components/classification_head.py index bd21f52..475957c 100644 --- a/torchTextClassifiers/model/components/classification_head.py +++ b/torchTextClassifiers/model/components/classification_head.py @@ -25,7 +25,7 @@ def __init__( super().__init__() if net is not None: self.net = net - + # --- Custom net should either be a Sequential or a Linear --- if not (isinstance(net, nn.Sequential) or isinstance(net, nn.Linear)): - raise ValueError("net must be an nn.Sequential when provided.") + raise ValueError("net must be an nn.Sequential or nn.Linear when provided.") diff --git a/torchTextClassifiers/model/lightning.py b/torchTextClassifiers/model/lightning.py index 8726f20..1b9ed94 100644 --- a/torchTextClassifiers/model/lightning.py +++ b/torchTextClassifiers/model/lightning.py @@ -1,9 +1,8 @@ import pytorch_lightning as pl import torch +from torch import nn from torchmetrics import Accuracy -from .model import TextClassificationModel - # ============================================================================ # PyTorch Lightning Module # ============================================================================ @@ -14,7 +13,7 @@ class TextClassificationModule(pl.LightningModule): def __init__( self, - model: TextClassificationModel, + model: nn.Module, loss, optimizer, optimizer_params, @@ -36,11 +35,23 @@ def __init__( scheduler_interval: Scheduler interval. 
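+
+        Note:
+            If ``model.num_classes`` is a list (multi-task model), one
+            ``Accuracy`` metric is created per level and logged as
+            ``train_accuracy_level_{i}`` / ``val_accuracy_level_{i}``.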
""" super().__init__() - self.save_hyperparameters(ignore=["model"]) + self.save_hyperparameters(ignore=["model", "loss"]) self.model = model self.loss = loss - self.accuracy_fn = Accuracy(task="multiclass", num_classes=self.model.num_classes) + + if not hasattr(self.model, "num_classes") or self.model.num_classes is None: + raise ValueError("Model must have num_classes attribute for accuracy calculation.") + + if isinstance(self.model.num_classes, list): + self.accuracy_fn = torch.nn.ModuleList( + [Accuracy(task="multiclass", num_classes=n) for n in self.model.num_classes] + ) + self.multilevel_accuracy = True + else: + self.accuracy_fn = Accuracy(task="multiclass", num_classes=self.model.num_classes) + self.multilevel_accuracy = False + self.optimizer = optimizer self.optimizer_params = optimizer_params self.scheduler = scheduler @@ -62,71 +73,42 @@ def forward(self, batch) -> torch.Tensor: categorical_vars=batch.get("categorical_vars", None), ) - def training_step(self, batch, batch_idx: int) -> torch.Tensor: - """ - Training step. - - Args: - batch (List[torch.LongTensor]): Training batch. - batch_idx (int): Batch index. - - Returns (torch.Tensor): Loss tensor. - """ - + def step(self, batch) -> tuple[torch.Tensor, torch.Tensor | list[torch.Tensor]]: targets = batch["labels"] - outputs = self.forward(batch) - if isinstance(self.loss, torch.nn.BCEWithLogitsLoss): targets = targets.float() - loss = self.loss(outputs, targets) - self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True) - accuracy = self.accuracy_fn(outputs, targets) - self.log("train_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True) + if self.multilevel_accuracy: + accuracy = [ + fn(out, targets[:, i]) for i, (fn, out) in enumerate(zip(self.accuracy_fn, outputs)) + ] + else: + accuracy = self.accuracy_fn(outputs, targets) + return loss, accuracy - torch.cuda.empty_cache() + def _log_accuracy(self, accuracy: torch.Tensor | list[torch.Tensor], prefix: str, **kwargs): + if isinstance(accuracy, list): + for i, acc in enumerate(accuracy): + self.log(f"{prefix}_accuracy_level_{i}", acc, **kwargs) + else: + self.log(f"{prefix}_accuracy", accuracy, **kwargs) + def training_step(self, batch, batch_idx: int) -> torch.Tensor: + loss, accuracy = self.step(batch) + self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True) + self._log_accuracy(accuracy, "train", on_epoch=True, on_step=False, prog_bar=True) + torch.cuda.empty_cache() return loss def validation_step(self, batch, batch_idx: int): - """ - Validation step. - - Args: - batch (List[torch.LongTensor]): Validation batch. - batch_idx (int): Batch index. - - Returns (torch.Tensor): Loss tensor. - """ - targets = batch["labels"] - - outputs = self.forward(batch) - - loss = self.loss(outputs, targets) + loss, accuracy = self.step(batch) self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True) - - accuracy = self.accuracy_fn(outputs, targets) - self.log("val_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True) + self._log_accuracy(accuracy, "val", on_epoch=True, on_step=False, prog_bar=True) return loss def test_step(self, batch, batch_idx: int): - """ - Test step. - - Args: - batch (List[torch.LongTensor]): Test batch. - batch_idx (int): Batch index. - - Returns (torch.Tensor): Loss tensor. 
- """ - targets = batch["labels"] - - outputs = self.forward(batch) - loss = self.loss(outputs, targets) - - accuracy = self.accuracy_fn(outputs, targets) - + loss, accuracy = self.step(batch) return loss, accuracy def predict_step(self, batch, batch_idx: int = 0, dataloader_idx: int = 0): diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py index b5aeb10..3b05477 100644 --- a/torchTextClassifiers/torchTextClassifiers.py +++ b/torchTextClassifiers/torchTextClassifiers.py @@ -1,848 +1,987 @@ -import logging -import pickle -import time -from dataclasses import asdict, dataclass, field -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Type, Union - -try: - from captum.attr import LayerIntegratedGradients - - HAS_CAPTUM = True -except ImportError: - HAS_CAPTUM = False - - -import numpy as np -import pytorch_lightning as pl -import torch -from pytorch_lightning.callbacks import ( - EarlyStopping, - LearningRateMonitor, - ModelCheckpoint, -) - -from torchTextClassifiers.dataset import TextClassificationDataset -from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule -from torchTextClassifiers.model.components import ( - AttentionConfig, - CategoricalForwardType, - CategoricalVariableNet, - ClassificationHead, - LabelAttentionConfig, - SentenceEmbedder, - SentenceEmbedderConfig, - TokenEmbedder, - TokenEmbedderConfig, -) -from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput -from torchTextClassifiers.value_encoder import ValueEncoder - -logger = logging.getLogger(__name__) - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - handlers=[logging.StreamHandler()], -) - - -@dataclass -class ModelConfig: - """Base configuration class for text classifiers.""" - - embedding_dim: int - num_classes: Optional[int] = None - categorical_vocabulary_sizes: Optional[List[int]] = None - categorical_embedding_dims: Optional[Union[List[int], int]] = None - attention_config: Optional[AttentionConfig] = None - n_heads_label_attention: Optional[int] = None - aggregation_method: Optional[str] = "mean" - - def to_dict(self) -> Dict[str, Any]: - return asdict(self) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ModelConfig": - return cls(**data) - - -@dataclass -class TrainingConfig: - num_epochs: int - batch_size: int - lr: float - raw_categorical_inputs: Optional[bool] = True - raw_labels: Optional[bool] = True - loss: torch.nn.Module = field(default_factory=lambda: torch.nn.CrossEntropyLoss()) - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam - scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None - accelerator: str = "auto" - num_workers: int = 12 - patience_early_stopping: int = 3 - dataloader_params: Optional[dict] = None - trainer_params: Optional[dict] = None - optimizer_params: Optional[dict] = None - scheduler_params: Optional[dict] = None - save_path: Optional[str] = "my_ttc" - - def to_dict(self) -> Dict[str, Any]: - data = asdict(self) - # Serialize loss and scheduler as their class names - data["loss"] = self.loss.__class__.__name__ - if self.scheduler is not None: - data["scheduler"] = self.scheduler.__name__ - return data - - -class torchTextClassifiers: - """Generic text classifier framework supporting multiple architectures. 
- - Given a tokenizer and model configuration, this class initializes: - - Text embedding layer (if needed) - - Categorical variable embedding network (if categorical variables are provided) - - Classification head - The resulting model can be trained using PyTorch Lightning and used for predictions. - - """ - - def __init__( - self, - tokenizer: BaseTokenizer, - model_config: ModelConfig, - ragged_multilabel: bool = False, - value_encoder: Optional[ValueEncoder] = None, - ): - """Initialize the torchTextClassifiers instance. - - Args: - tokenizer: A tokenizer instance for text preprocessing - model_config: Configuration parameters for the text classification model - ragged_multilabel: Whether to use ragged multilabel classification - value_encoder: Optional ValueEncoder for encoding - raw string (or mixed) categorical values to integers. Build it - beforehand from DictEncoder or sklearn LabelEncoder instances and - pass it here. If None, categorical columns in X must already be - integer-encoded. - - Example: - >>> from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers - >>> from torchTextClassifiers.value_encoder import ValueEncoder, DictEncoder - >>> # Build one DictEncoder per categorical feature - >>> encoders = {str(i): DictEncoder({v: j for j, v in enumerate(sorted(set(X_categorical[:, i])))}) - ... for i in range(X_categorical.shape[1])} - >>> encoder = ValueEncoder(encoders) - >>> model_config = ModelConfig( - ... embedding_dim=10, - ... categorical_vocabulary_sizes=encoder.vocabulary_sizes, - ... categorical_embedding_dims=[10, 5], - ... num_classes=10, - ... ) - >>> ttc = torchTextClassifiers( - ... tokenizer=tokenizer, - ... model_config=model_config, - ... value_encoder=encoder, - ... ) - """ - - self.model_config = model_config - self.tokenizer = tokenizer - self.ragged_multilabel = ragged_multilabel - self.value_encoder: ValueEncoder | None = value_encoder - - if hasattr(self.tokenizer, "trained"): - if not self.tokenizer.trained: - raise RuntimeError( - f"Tokenizer {type(self.tokenizer)} must be trained before initializing the classifier." - ) - - self.vocab_size = tokenizer.vocab_size - self.embedding_dim = model_config.embedding_dim - - if self.value_encoder is not None: - if (model_config.num_classes != self.value_encoder.num_classes) or ( - model_config.categorical_vocabulary_sizes != self.value_encoder.vocabulary_sizes - ): - logger.info( - "Overriding model_config num_classes and/or categorical_vocabulary_sizes with values from value_encoder." - ) - self.categorical_vocabulary_sizes = self.value_encoder.vocabulary_sizes - self.num_classes = self.value_encoder.num_classes - else: - self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes - if model_config.num_classes is None: - raise ValueError( - "num_classes must be specified in the model configuration if no value_encoder is provided." - ) - self.num_classes = model_config.num_classes - - self.enable_label_attention = model_config.n_heads_label_attention is not None - - if self.tokenizer.output_vectorized: - self.token_embedder = None - logger.info( - "Tokenizer outputs vectorized tokens; skipping TextEmbedder initialization." 
- ) - self.embedding_dim = self.tokenizer.output_dim - else: - token_embedder_config = TokenEmbedderConfig( - vocab_size=self.vocab_size, - embedding_dim=self.embedding_dim, - padding_idx=tokenizer.padding_idx, - attention_config=model_config.attention_config, - ) - sentence_embedder_config = SentenceEmbedderConfig( - label_attention_config=LabelAttentionConfig( - n_head=model_config.n_heads_label_attention, - num_classes=model_config.num_classes, - embedding_dim=self.embedding_dim, - ) - if self.enable_label_attention - else None, - aggregation_method=model_config.aggregation_method, - ) - self.token_embedder = TokenEmbedder( - token_embedder_config=token_embedder_config, - ) - self.sentence_embedder = SentenceEmbedder( - sentence_embedder_config=sentence_embedder_config - ) - - classif_head_input_dim = self.embedding_dim - if self.categorical_vocabulary_sizes: - self.categorical_var_net = CategoricalVariableNet( - categorical_vocabulary_sizes=self.categorical_vocabulary_sizes, - categorical_embedding_dims=model_config.categorical_embedding_dims, - text_embedding_dim=self.embedding_dim, - ) - - if self.categorical_var_net.forward_type != CategoricalForwardType.SUM_TO_TEXT: - classif_head_input_dim += self.categorical_var_net.output_dim - - else: - self.categorical_var_net = None - - self.classification_head = ClassificationHead( - input_dim=classif_head_input_dim, - num_classes=1 - if self.enable_label_attention - else self.num_classes, # output dim is 1 when using label attention, because embeddings are (num_classes, embedding_dim) - ) - - self.pytorch_model = TextClassificationModel( - token_embedder=self.token_embedder, - sentence_embedder=self.sentence_embedder, - categorical_variable_net=self.categorical_var_net, - classification_head=self.classification_head, - ) - - def train( - self, - X_train: np.ndarray, - y_train: np.ndarray, - training_config: TrainingConfig, - X_val: Optional[np.ndarray] = None, - y_val: Optional[np.ndarray] = None, - verbose: bool = False, - ) -> None: - """Train the classifier using PyTorch Lightning. - - This method handles the complete training process including: - - Data validation and preprocessing - - Dataset and DataLoader creation - - PyTorch Lightning trainer setup with callbacks - - Model training with early stopping - - Best model loading after training - - Note on Checkpoints: - After training, the best model checkpoint is automatically loaded. - This checkpoint contains the full training state (model weights, - optimizer, and scheduler state). Loading uses weights_only=False - as the checkpoint is self-generated and trusted. - - Args: - X_train: Training input data - y_train: Training labels - X_val: Validation input data - y_val: Validation labels - training_config: Configuration parameters for training - verbose: Whether to print training progress information - - - Example: - - >>> training_config = TrainingConfig( - ... lr=1e-3, - ... batch_size=4, - ... num_epochs=1, - ... ) - >>> ttc.train( - ... X_train=X, - ... y_train=Y, - ... X_val=X, - ... y_val=Y, - ... training_config=training_config, - ... ) - """ - - # Input validation - X_train, y_train = self._check_XY( - X_train, y_train, training_config.raw_categorical_inputs, training_config.raw_labels - ) - - if X_val is not None: - assert y_val is not None, "y_val must be provided if X_val is provided." - if y_val is not None: - assert X_val is not None, "X_val must be provided if y_val is provided." 
- - X_val: Optional[Dict[str, Any]] = None - if X_val is not None and y_val is not None: - X_val, y_val = self._check_XY(X_val, y_val) - - if ( - X_train["categorical_variables"] is not None - and X_val is not None - and X_val["categorical_variables"] is not None - ): - assert ( - X_train["categorical_variables"].ndim > 1 - and X_train["categorical_variables"].shape[1] - == X_val["categorical_variables"].shape[1] - or X_val["categorical_variables"].ndim == 1 - ), "X_train and X_val must have the same number of columns." - - if verbose: - logger.info("Starting training process...") - - if training_config.accelerator == "auto": - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - else: - device = torch.device(training_config.accelerator) - - self.device = device - - optimizer_params = {"lr": training_config.lr} - if training_config.optimizer_params is not None: - optimizer_params.update(training_config.optimizer_params) - - if training_config.loss is torch.nn.CrossEntropyLoss and self.ragged_multilabel: - logger.warning( - "āš ļø You have set ragged_multilabel to True but are using CrossEntropyLoss. We would recommend to use torch.nn.BCEWithLogitsLoss for multilabel classification tasks." - ) - - self.lightning_module = TextClassificationModule( - model=self.pytorch_model, - loss=training_config.loss, - optimizer=training_config.optimizer, - optimizer_params=optimizer_params, - scheduler=training_config.scheduler, - scheduler_params=training_config.scheduler_params - if training_config.scheduler_params - else {}, - scheduler_interval="epoch", - ) - - self.pytorch_model.to(self.device) - - if verbose: - logger.info(f"Running on: {device}") - - train_dataset = TextClassificationDataset( - texts=X_train["text"], - categorical_variables=X_train["categorical_variables"], # None if no cat vars - tokenizer=self.tokenizer, - labels=y_train.tolist(), - ragged_multilabel=self.ragged_multilabel, - ) - train_dataloader = train_dataset.create_dataloader( - batch_size=training_config.batch_size, - num_workers=training_config.num_workers, - shuffle=True, - **training_config.dataloader_params if training_config.dataloader_params else {}, - ) - - if X_val is not None and y_val is not None: - val_dataset = TextClassificationDataset( - texts=X_val["text"], - categorical_variables=X_val["categorical_variables"], # None if no cat vars - tokenizer=self.tokenizer, - labels=y_val, - ragged_multilabel=self.ragged_multilabel, - ) - val_dataloader = val_dataset.create_dataloader( - batch_size=training_config.batch_size, - num_workers=training_config.num_workers, - shuffle=False, - **training_config.dataloader_params if training_config.dataloader_params else {}, - ) - else: - val_dataloader = None - - # Setup trainer - callbacks = [ - ModelCheckpoint( - monitor="val_loss" if val_dataloader is not None else "train_loss", - save_top_k=1, - save_last=False, - mode="min", - ), - EarlyStopping( - monitor="val_loss" if val_dataloader is not None else "train_loss", - patience=training_config.patience_early_stopping, - mode="min", - ), - LearningRateMonitor(logging_interval="step"), - ] - - trainer_params = { - "accelerator": training_config.accelerator, - "callbacks": callbacks, - "max_epochs": training_config.num_epochs, - "num_sanity_val_steps": 2, - "strategy": "auto", - "log_every_n_steps": 1, - "enable_progress_bar": True, - } - - if training_config.trainer_params is not None: - trainer_params.update(training_config.trainer_params) - - trainer = pl.Trainer(**trainer_params) - - 
torch.cuda.empty_cache() - torch.set_float32_matmul_precision("medium") - - if verbose: - logger.info("Launching training...") - start = time.time() - - trainer.fit(self.lightning_module, train_dataloader, val_dataloader) - - if verbose: - end = time.time() - logger.info(f"Training completed in {end - start:.2f} seconds.") - - best_model_path = trainer.checkpoint_callback.best_model_path - self.checkpoint_path = best_model_path - - self.lightning_module = TextClassificationModule.load_from_checkpoint( - best_model_path, - model=self.pytorch_model, - loss=training_config.loss, - weights_only=False, # Required: checkpoint contains optimizer/scheduler state - ) - - self.pytorch_model = self.lightning_module.model.to(self.device) - - self.save_path = training_config.save_path - self.save(self.save_path) - - self.lightning_module.eval() - - def _check_XY( - self, X: np.ndarray, Y: np.ndarray, raw_categorical_inputs, raw_labels - ) -> Tuple[Dict[str, Any], np.ndarray]: - X_checked = self._check_X(X, raw_categorical_inputs) - Y_checked = self._check_Y(Y, raw_labels) - - if X_checked["text"].shape[0] != len(Y_checked): - raise ValueError("X_train and y_train must have the same number of observations.") - - return X_checked, Y_checked - - @staticmethod - def _check_text_col(X): - assert isinstance( - X, np.ndarray - ), "X must be a numpy array of shape (N,d), with the first column being the text and the rest being the categorical variables." - - try: - if X.ndim > 1: - text = X[:, 0].astype(str) - else: - text = X[:].astype(str) - except ValueError: - logger.error("The first column of X must be castable in string format.") - - return text - - def _check_categorical_variables( - self, X: np.ndarray, raw_categorical_inputs: bool - ) -> np.ndarray: - """Validate and encode categorical variables from X. - - If a ``value_encoder`` was provided at initialization, raw string - or mixed values are encoded to integers via that encoder. Otherwise the - categorical columns must already be integer-encodable. - - Args: - X: Full input array whose first column is text and whose remaining - columns are categorical variables. - - Returns: - Integer-encoded categorical array of shape (N, n_cat_features). - - Raises: - ValueError: If the number of categorical features does not match the - model configuration, if values exceed vocabulary bounds, or if - values cannot be cast to integers and no encoder was provided. - """ - assert self.categorical_var_net is not None - - num_cat_vars = X.shape[1] - 1 if X.ndim > 1 else 0 - - if num_cat_vars != self.categorical_var_net.num_categorical_features: - raise ValueError( - f"X must have the same number of categorical variables as the number of " - f"embedding layers in the categorical net: ({self.categorical_var_net.num_categorical_features})." - ) - - if raw_categorical_inputs: - if self.value_encoder is None: - raise ValueError( - "Raw categorical input encoding is enabled, but no value_encoder was provided. Please provide a ValueEncoder to encode raw categorical values to integers." - ) - categorical_variables = self.value_encoder.transform(X[:, 1:]).astype(int) - else: - categorical_variables = X[:, 1:].astype(int) - - for j in range(num_cat_vars): - max_cat_value = categorical_variables[:, j].max() - if max_cat_value >= self.categorical_var_net.categorical_vocabulary_sizes[j]: - raise ValueError( - f"Categorical variable at index {j} has value {max_cat_value} which exceeds " - f"the vocabulary size of {self.categorical_var_net.categorical_vocabulary_sizes[j]}." 
- ) - - return categorical_variables - - def _check_X(self, X: np.ndarray, raw_categorical_inputs: bool) -> Dict[str, Any]: - text = self._check_text_col(X) - - categorical_variables = None - if self.categorical_var_net is not None: - categorical_variables = self._check_categorical_variables(X, raw_categorical_inputs) - - return {"text": text, "categorical_variables": categorical_variables} - - def _check_Y(self, Y, raw_labels: bool) -> np.ndarray: - if self.ragged_multilabel: - assert isinstance( - Y, list - ), "Y must be a list of lists for ragged multilabel classification." - for row in Y: - assert isinstance(row, list), "Each element of Y must be a list of labels." - - return Y - - else: - assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)." - assert ( - len(Y.shape) == 1 or len(Y.shape) == 2 - ), "Y must be a numpy array of shape (N,) or (N, num_labels)." - - if raw_labels: - if self.value_encoder is None: - raise ValueError( - "Raw label encoding is enabled, but no value_encoder was provided. Please provide a ValueEncoder to encode raw labels to integers." - ) - Y = self.value_encoder.transform_labels(Y) - Y = Y.astype(int) - - if Y.max() >= self.num_classes or Y.min() < 0: - raise ValueError( - f"Y contains class labels outside the range [0, {self.num_classes - 1}]." - ) - - return Y - - def predict( - self, - X_test: np.ndarray, - raw_categorical_inputs: bool = True, - top_k=1, - explain_with_label_attention: bool = False, - explain_with_captum=False, - ): - """ - Args: - X_test (np.ndarray): input data to predict on, shape (N,d) where the first column is text and the rest are categorical variables - top_k (int): for each sentence, return the top_k most likely predictions (default: 1) - explain_with_label_attention (bool): if enabled, use attention matrix labels x tokens to have an explanation of the prediction (default: False) - explain_with_captum (bool): launch gradient integration with Captum for explanation (default: False) - - Returns: A dictionary containing the following fields: - - predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes to the query. - - confidence (torch.Tensor, shape (len(text), top_k)): A tensor array containing the corresponding confidence scores. - - if explain is True: - - attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text. - """ - - explain = explain_with_label_attention or explain_with_captum - if explain: - return_offsets_mapping = True # to be passed to the tokenizer - return_word_ids = True - if self.pytorch_model.token_embedder is None: - raise RuntimeError( - "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs." - ) - else: - if explain_with_captum: - if not HAS_CAPTUM: - raise ImportError( - "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'." - ) - lig = LayerIntegratedGradients( - self.pytorch_model, self.pytorch_model.token_embedder.embedding_layer - ) # initialize a Captum layer gradient integrator - if explain_with_label_attention: - if not self.enable_label_attention: - raise RuntimeError( - "Label attention explainability is enabled, but the model was not configured with label attention. Please enable label attention in the model configuration during initialization and retrain." 
- ) - else: - return_offsets_mapping = False - return_word_ids = False - - X_test = self._check_X(X_test, raw_categorical_inputs) - text = X_test["text"] - categorical_variables = X_test["categorical_variables"] - - self.pytorch_model.eval().cpu() - - tokenize_output = self.tokenizer.tokenize( - text.tolist(), - return_offsets_mapping=return_offsets_mapping, - return_word_ids=return_word_ids, - ) - - if not isinstance(tokenize_output, TokenizerOutput): - raise TypeError( - f"Expected TokenizerOutput, got {type(tokenize_output)} from tokenizer.tokenize method." - ) - - encoded_text = tokenize_output.input_ids # (batch_size, seq_len) - attention_mask = tokenize_output.attention_mask # (batch_size, seq_len) - - if categorical_variables is not None: - categorical_vars = torch.tensor( - categorical_variables, dtype=torch.float32 - ) # (batch_size, num_categorical_features) - else: - categorical_vars = torch.empty((encoded_text.shape[0], 0), dtype=torch.float32) - - model_output = self.pytorch_model( - encoded_text, - attention_mask, - categorical_vars, - return_label_attention_matrix=explain_with_label_attention, - ) # forward pass, contains the prediction scores (len(text), num_classes) - pred = ( - model_output["logits"] if explain_with_label_attention else model_output - ) # (batch_size, num_classes) - - label_attention_matrix = ( - model_output["label_attention_matrix"] if explain_with_label_attention else None - ) - - label_scores = pred.detach().cpu().softmax(dim=1) # convert to probabilities - - label_scores_topk = torch.topk(label_scores, k=top_k, dim=1) - - integer_predictions = label_scores_topk.indices # integer class indices (needed for captum) - if self.value_encoder is not None: - predictions = self.value_encoder.inverse_transform_labels(integer_predictions.numpy()) - else: - predictions = integer_predictions - - confidence = torch.round(label_scores_topk.values, decimals=2) # and their scores - - if explain: - if explain_with_captum: - # Captum explanations - captum_attributions = [] - for k in range(top_k): - attributions = lig.attribute( - (encoded_text, attention_mask, categorical_vars), - target=integer_predictions[:, k], - ) # (batch_size, seq_len) - attributions = attributions.sum(dim=-1) - captum_attributions.append(attributions.detach().cpu()) - - captum_attributions = torch.stack( - captum_attributions, dim=1 - ) # (batch_size, top_k, seq_len) - else: - captum_attributions = None - - return { - "prediction": predictions, - "confidence": confidence, - "captum_attributions": captum_attributions, - "label_attention_attributions": label_attention_matrix, - "offset_mapping": tokenize_output.offset_mapping, - "word_ids": tokenize_output.word_ids, - } - else: - return { - "prediction": predictions, - "confidence": confidence, - } - - def save(self, path: Union[str, Path]) -> None: - """Save the complete torchTextClassifiers instance to disk. 
- - This saves: - - Model configuration - - Tokenizer state - - PyTorch Lightning checkpoint (if trained) - - All other instance attributes - - Args: - path: Directory path where the model will be saved - - Example: - >>> ttc = torchTextClassifiers(tokenizer, model_config) - >>> ttc.train(X_train, y_train, training_config) - >>> ttc.save("my_model") - """ - path = Path(path) - path.mkdir(parents=True, exist_ok=True) - - # Save the checkpoint if model has been trained - checkpoint_path = None - if hasattr(self, "lightning_module"): - checkpoint_path = path / "model_checkpoint.ckpt" - # Save the current state as a checkpoint - trainer = pl.Trainer() - trainer.strategy.connect(self.lightning_module) - trainer.save_checkpoint(checkpoint_path) - - # Prepare metadata to save - metadata = { - "model_config": self.model_config.to_dict(), - "ragged_multilabel": self.ragged_multilabel, - "vocab_size": self.vocab_size, - "embedding_dim": self.embedding_dim, - "categorical_vocabulary_sizes": self.categorical_vocabulary_sizes, - "num_classes": self.num_classes, - "checkpoint_path": str(checkpoint_path) if checkpoint_path else None, - "device": str(self.device) if hasattr(self, "device") else None, - "has_value_encoder": self.value_encoder is not None, - } - - # Save metadata - with open(path / "metadata.pkl", "wb") as f: - pickle.dump(metadata, f) - - # Save tokenizer - tokenizer_path = path / "tokenizer.pkl" - with open(tokenizer_path, "wb") as f: - pickle.dump(self.tokenizer, f) - - # Save categorical encoder if present - if self.value_encoder is not None: - with open(path / "value_encoder.pkl", "wb") as f: - pickle.dump(self.value_encoder, f) - - logger.info(f"Model saved successfully to {path}") - - @classmethod - def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassifiers": - """Load a torchTextClassifiers instance from disk. - - Args: - path: Directory path where the model was saved - device: Device to load the model on ('auto', 'cpu', 'cuda', etc.) 
- - Returns: - Loaded torchTextClassifiers instance - - Example: - >>> loaded_ttc = torchTextClassifiers.load("my_model") - >>> predictions = loaded_ttc.predict(X_test) - """ - path = Path(path) - - if not path.exists(): - raise FileNotFoundError(f"Model directory not found: {path}") - - # Load metadata - with open(path / "metadata.pkl", "rb") as f: - metadata = pickle.load(f) - - # Load tokenizer - with open(path / "tokenizer.pkl", "rb") as f: - tokenizer = pickle.load(f) - - # Reconstruct model_config - model_config = ModelConfig.from_dict(metadata["model_config"]) - - # Load categorical encoder if one was saved - value_encoder = None - if metadata.get("has_value_encoder"): - encoder_path = path / "value_encoder.pkl" - if encoder_path.exists(): - with open(encoder_path, "rb") as f: - value_encoder = pickle.load(f) - - # Create instance - instance = cls( - tokenizer=tokenizer, - model_config=model_config, - ragged_multilabel=metadata["ragged_multilabel"], - value_encoder=value_encoder, - ) - - # Set device - if device == "auto": - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - else: - device = torch.device(device) - instance.device = device - - # Load checkpoint if it exists - if metadata["checkpoint_path"]: - checkpoint_path = path / "model_checkpoint.ckpt" - if checkpoint_path.exists(): - # Load the checkpoint with weights_only=False since it's our own trusted checkpoint - instance.lightning_module = TextClassificationModule.load_from_checkpoint( - str(checkpoint_path), - model=instance.pytorch_model, - weights_only=False, - ) - instance.pytorch_model = instance.lightning_module.model.to(device) - instance.checkpoint_path = str(checkpoint_path) - logger.info(f"Model checkpoint loaded from {checkpoint_path}") - else: - logger.warning(f"Checkpoint file not found at {checkpoint_path}") - - logger.info(f"Model loaded successfully from {path}") - return instance - - def __repr__(self): - model_type = ( - self.lightning_module.__repr__() - if hasattr(self, "lightning_module") - else self.pytorch_model.__repr__() - ) - - tokenizer_info = self.tokenizer.__repr__() - - cat_forward_type = ( - self.categorical_var_net.forward_type.name - if self.categorical_var_net is not None - else "None" - ) - - lines = [ - "torchTextClassifiers(", - f" tokenizer = {tokenizer_info},", - f" model = {model_type},", - f" categorical_forward_type = {cat_forward_type},", - f" num_classes = {self.model_config.num_classes},", - f" embedding_dim = {self.embedding_dim},", - ")", - ] - return "\n".join(lines) +import logging +import pickle +import time +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast + +try: + from captum.attr import LayerIntegratedGradients + + HAS_CAPTUM = True +except ImportError: + HAS_CAPTUM = False + + +import numpy as np +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import ( + EarlyStopping, + LearningRateMonitor, + ModelCheckpoint, +) +from torch import nn + +from torchTextClassifiers.dataset import TextClassificationDataset +from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule +from torchTextClassifiers.model.components import ( + AttentionConfig, + CategoricalForwardType, + CategoricalVariableNet, + ClassificationHead, + LabelAttentionConfig, + SentenceEmbedder, + SentenceEmbedderConfig, + TokenEmbedder, + TokenEmbedderConfig, +) +from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput 
+from torchTextClassifiers.value_encoder import ValueEncoder + +logger = logging.getLogger(__name__) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler()], +) + + +@dataclass +class ModelConfig: + """Base configuration class for text classifiers.""" + + embedding_dim: int + num_classes: Optional[int | list[int]] = None + categorical_vocabulary_sizes: Optional[List[int]] = None + categorical_embedding_dims: Optional[Union[List[int], int]] = None + attention_config: Optional[AttentionConfig] = None + n_heads_label_attention: Optional[int] = None + aggregation_method: Optional[str] = "mean" + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ModelConfig": + return cls(**data) + + +@dataclass +class TrainingConfig: + num_epochs: int + batch_size: int + lr: float + raw_categorical_inputs: Optional[bool] = True + raw_labels: Optional[bool] = True + loss: torch.nn.Module = field(default_factory=lambda: torch.nn.CrossEntropyLoss()) + optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam + scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None + accelerator: str = "auto" + num_workers: int = 12 + patience_early_stopping: int = 3 + dataloader_params: Optional[dict] = None + trainer_params: Optional[dict] = None + optimizer_params: Optional[dict] = None + scheduler_params: Optional[dict] = None + save_path: Optional[str] = "my_ttc" + + def to_dict(self) -> Dict[str, Any]: + data = asdict(self) + # Serialize loss and scheduler as their class names + data["loss"] = self.loss.__class__.__name__ + if self.scheduler is not None: + data["scheduler"] = self.scheduler.__name__ + return data + + +class torchTextClassifiers: + """Generic text classifier framework supporting multiple architectures. + + Given a tokenizer and model configuration, this class initializes: + - Text embedding layer (if needed) + - Categorical variable embedding network (if categorical variables are provided) + - Classification head + The resulting model can be trained using PyTorch Lightning and used for predictions. + + """ + + def __init__( + self, + tokenizer: BaseTokenizer, + model_config: ModelConfig, + ragged_multilabel: bool = False, + value_encoder: Optional[ValueEncoder] = None, + ): + """Initialize the torchTextClassifiers instance. + + Args: + tokenizer: A tokenizer instance for text preprocessing + model_config: Configuration parameters for the text classification model + ragged_multilabel: Whether to use ragged multilabel classification + value_encoder: Optional ValueEncoder for encoding + raw string (or mixed) categorical values to integers. Build it + beforehand from DictEncoder or sklearn LabelEncoder instances and + pass it here. If None, categorical columns in X must already be + integer-encoded. + + Example: + >>> from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers + >>> from torchTextClassifiers.value_encoder import ValueEncoder, DictEncoder + >>> # Build one DictEncoder per categorical feature + >>> encoders = {str(i): DictEncoder({v: j for j, v in enumerate(sorted(set(X_categorical[:, i])))}) + ... for i in range(X_categorical.shape[1])} + >>> encoder = ValueEncoder(encoders) + >>> model_config = ModelConfig( + ... embedding_dim=10, + ... categorical_vocabulary_sizes=encoder.vocabulary_sizes, + ... categorical_embedding_dims=[10, 5], + ... num_classes=10, + ... 
) + >>> ttc = torchTextClassifiers( + ... tokenizer=tokenizer, + ... model_config=model_config, + ... value_encoder=encoder, + ... ) + """ + + self.model_config = model_config + self.tokenizer = tokenizer + self.ragged_multilabel = ragged_multilabel + self.value_encoder: ValueEncoder | None = value_encoder + + if hasattr(self.tokenizer, "trained"): + if not self.tokenizer.trained: + raise RuntimeError( + f"Tokenizer {type(self.tokenizer)} must be trained before initializing the classifier." + ) + + self.vocab_size = tokenizer.vocab_size + self.embedding_dim = model_config.embedding_dim + + if self.value_encoder is not None: + if (model_config.num_classes != self.value_encoder.num_classes) or ( + model_config.categorical_vocabulary_sizes != self.value_encoder.vocabulary_sizes + ): + logger.info( + "Overriding model_config num_classes and/or categorical_vocabulary_sizes with values from value_encoder." + ) + self.categorical_vocabulary_sizes = self.value_encoder.vocabulary_sizes + self.num_classes = self.value_encoder.num_classes + else: + self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes + if model_config.num_classes is None: + raise ValueError( + "num_classes must be specified in the model configuration if no value_encoder is provided." + ) + self.num_classes = model_config.num_classes + + self.enable_label_attention = model_config.n_heads_label_attention is not None + + if self.tokenizer.output_vectorized: + self.token_embedder = None + logger.info( + "Tokenizer outputs vectorized tokens; skipping TextEmbedder initialization." + ) + self.embedding_dim = self.tokenizer.output_dim + else: + token_embedder_config = TokenEmbedderConfig( + vocab_size=self.vocab_size, + embedding_dim=self.embedding_dim, + padding_idx=tokenizer.padding_idx, + attention_config=model_config.attention_config, + ) + sentence_embedder_config = SentenceEmbedderConfig( + label_attention_config=LabelAttentionConfig( + n_head=model_config.n_heads_label_attention, + num_classes=model_config.num_classes, + embedding_dim=self.embedding_dim, + ) + if self.enable_label_attention + else None, + aggregation_method=model_config.aggregation_method, + ) + self.token_embedder = TokenEmbedder( + token_embedder_config=token_embedder_config, + ) + self.sentence_embedder = SentenceEmbedder( + sentence_embedder_config=sentence_embedder_config + ) + + classif_head_input_dim = self.embedding_dim + if self.categorical_vocabulary_sizes: + self.categorical_var_net = CategoricalVariableNet( + categorical_vocabulary_sizes=self.categorical_vocabulary_sizes, + categorical_embedding_dims=model_config.categorical_embedding_dims, + text_embedding_dim=self.embedding_dim, + ) + + if self.categorical_var_net.forward_type != CategoricalForwardType.SUM_TO_TEXT: + classif_head_input_dim += self.categorical_var_net.output_dim + + else: + self.categorical_var_net = None + + self.classification_head = ClassificationHead( + input_dim=classif_head_input_dim, + num_classes=1 + if self.enable_label_attention + else self.num_classes, # output dim is 1 when using label attention, because embeddings are (num_classes, embedding_dim) + ) + + self.pytorch_model = TextClassificationModel( + token_embedder=self.token_embedder, + sentence_embedder=self.sentence_embedder, + categorical_variable_net=self.categorical_var_net, + classification_head=self.classification_head, + ) + + @classmethod + def from_model( + cls, + tokenizer: BaseTokenizer, + pytorch_model: nn.Module, + value_encoder: Optional[ValueEncoder] = None, + ragged_multilabel: 
Optional[bool] = False, + ): + """Initialize torchTextClassifiers from a custom pre-built PyTorch model. + + Use this when the standard ``TextClassificationModel`` (built automatically + from ``ModelConfig``) cannot express your architecture — for example when you + need multiple classification heads, shared encoders across tasks, or any other + custom topology. The wrapper then provides the usual ``predict`` / ``save`` / + ``load`` interface around your model. + + **Required interface for** ``pytorch_model``: + + 1. **``forward`` signature** — the model must accept exactly these keyword + arguments (extra ``**kwargs`` are forwarded but ignored by the wrapper):: + + def forward( + self, + input_ids: torch.Tensor, # (batch, seq_len) Long + attention_mask: torch.Tensor, # (batch, seq_len) int + categorical_vars: torch.Tensor, # (batch, n_cats) Long — may be None + **kwargs, + ) -> torch.Tensor | list[torch.Tensor]: + ... + + The return value must be **raw logits** (not softmaxed). For standard + single-task classification return a tensor of shape + ``(batch, num_classes)``. For multi-task classification you may return a + list of such tensors, one per task. + + 2. **``num_classes`` attribute** — must be an ``int`` (single task) or a + ``list[int]`` (multi-task, one entry per task head). + + 3. **``categorical_variable_net`` attribute** — the ``CategoricalVariableNet`` + module used by the model, or ``None`` if no categorical features are used. + The wrapper reads ``categorical_variable_net.categorical_vocabulary_sizes`` + to set up the data pipeline. + + See ``torchTextClassifiers.contrib`` for ready-made example architectures + (``MultiLevelTextClassificationModel``, ``MultiLevelCrossEntropyLoss``) that + follow this interface. + + Args: + tokenizer: A tokenizer instance for text preprocessing. + pytorch_model: A pre-built PyTorch model satisfying the interface above. + value_encoder: Optional ``ValueEncoder`` for encoding raw string (or + mixed) categorical values to integers. Build it from ``DictEncoder`` + or sklearn ``LabelEncoder`` instances and pass it here. If ``None``, + categorical columns in ``X`` must already be integer-encoded. + ragged_multilabel: Set to ``True`` for ragged multi-label targets + (variable number of labels per sample). + + Returns: + An instance of torchTextClassifiers wrapping the provided model. + """ + instance = cls.__new__(cls) + instance.tokenizer = tokenizer + instance.pytorch_model = pytorch_model + instance.num_classes = pytorch_model.num_classes + instance.categorical_var_net = cast( + Optional[CategoricalVariableNet], pytorch_model.categorical_variable_net + ) + instance.value_encoder = value_encoder + instance.ragged_multilabel = ragged_multilabel + instance._custom_model = True + instance.enable_label_attention = False + instance.categorical_vocabulary_sizes = ( + instance.categorical_var_net.categorical_vocabulary_sizes + if instance.categorical_var_net is not None + else None + ) + return instance + + def train( + self, + X_train: np.ndarray, + y_train: np.ndarray, + training_config: TrainingConfig, + X_val: Optional[np.ndarray] = None, + y_val: Optional[np.ndarray] = None, + verbose: bool = False, + ) -> None: + """Train the classifier using PyTorch Lightning. 
+
+        This method handles the complete training process including:
+        - Data validation and preprocessing
+        - Dataset and DataLoader creation
+        - PyTorch Lightning trainer setup with callbacks
+        - Model training with early stopping
+        - Best model loading after training
+
+        Note on Checkpoints:
+            After training, the best model checkpoint is automatically loaded.
+            This checkpoint contains the full training state (model weights,
+            optimizer, and scheduler state). Loading uses weights_only=False
+            as the checkpoint is self-generated and trusted.
+
+        Args:
+            X_train: Training input data
+            y_train: Training labels
+            training_config: Configuration parameters for training
+            X_val: Optional validation input data
+            y_val: Optional validation labels
+            verbose: Whether to print training progress information
+
+        Example:
+
+            >>> training_config = TrainingConfig(
+            ...     lr=1e-3,
+            ...     batch_size=4,
+            ...     num_epochs=1,
+            ... )
+            >>> ttc.train(
+            ...     X_train=X,
+            ...     y_train=Y,
+            ...     X_val=X,
+            ...     y_val=Y,
+            ...     training_config=training_config,
+            ... )
+        """
+
+        # Input validation
+        X_train, y_train = self._check_XY(
+            X_train, y_train, training_config.raw_categorical_inputs, training_config.raw_labels
+        )
+
+        if X_val is not None:
+            assert y_val is not None, "y_val must be provided if X_val is provided."
+        if y_val is not None:
+            assert X_val is not None, "X_val must be provided if y_val is provided."
+
+        if X_val is not None and y_val is not None:
+            X_val, y_val = self._check_XY(
+                X_val, y_val, training_config.raw_categorical_inputs, training_config.raw_labels
+            )
+
+        if (
+            (X_train["categorical_variables"] is not None)
+            and (X_val is not None)
+            and (X_val["categorical_variables"] is not None)
+        ):
+            assert (
+                (
+                    X_train["categorical_variables"].ndim > 1
+                    and X_train["categorical_variables"].shape[1]
+                    == X_val["categorical_variables"].shape[1]
+                )
+                or X_val["categorical_variables"].ndim == 1
+            ), "X_train and X_val must have the same number of columns."
+
+        if verbose:
+            logger.info("Starting training process...")
+
+        if training_config.accelerator == "auto":
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        else:
+            device = torch.device(training_config.accelerator)
+
+        self.device = device
+
+        optimizer_params = {"lr": training_config.lr}
+        if training_config.optimizer_params is not None:
+            optimizer_params.update(training_config.optimizer_params)
+
+        if training_config.loss is torch.nn.CrossEntropyLoss and self.ragged_multilabel:
+            logger.warning(
+                "āš ļø ragged_multilabel is set to True but the loss is CrossEntropyLoss. "
+                "We recommend torch.nn.BCEWithLogitsLoss for multilabel classification tasks."
+            )
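+        # Wrap the bare PyTorch model in a LightningModule so that pl.Trainer
+        # drives the optimization loop, device placement, and checkpointing.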
+
+        self.lightning_module = TextClassificationModule(
+            model=self.pytorch_model,
+            loss=training_config.loss,
+            optimizer=training_config.optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=training_config.scheduler,
+            scheduler_params=training_config.scheduler_params or {},
+            scheduler_interval="epoch",
+        )
+
+        self.pytorch_model.to(self.device)
+
+        if verbose:
+            logger.info(f"Running on: {device}")
+
+        train_dataset = TextClassificationDataset(
+            texts=X_train["text"],
+            categorical_variables=X_train["categorical_variables"],  # None if no cat vars
+            tokenizer=self.tokenizer,
+            labels=y_train if self.ragged_multilabel else y_train.tolist(),  # ragged labels are already lists
+            ragged_multilabel=self.ragged_multilabel,
+        )
+        train_dataloader = train_dataset.create_dataloader(
+            batch_size=training_config.batch_size,
+            num_workers=training_config.num_workers,
+            shuffle=True,
+            **(training_config.dataloader_params or {}),
+        )
+
+        if X_val is not None and y_val is not None:
+            val_dataset = TextClassificationDataset(
+                texts=X_val["text"],
+                categorical_variables=X_val["categorical_variables"],  # None if no cat vars
+                tokenizer=self.tokenizer,
+                labels=y_val if self.ragged_multilabel else y_val.tolist(),
+                ragged_multilabel=self.ragged_multilabel,
+            )
+            val_dataloader = val_dataset.create_dataloader(
+                batch_size=training_config.batch_size,
+                num_workers=training_config.num_workers,
+                shuffle=False,
+                **(training_config.dataloader_params or {}),
+            )
+        else:
+            val_dataloader = None
+
+        # Setup trainer
+        callbacks = [
+            ModelCheckpoint(
+                monitor="val_loss" if val_dataloader is not None else "train_loss",
+                save_top_k=1,
+                save_last=False,
+                mode="min",
+            ),
+            EarlyStopping(
+                monitor="val_loss" if val_dataloader is not None else "train_loss",
+                patience=training_config.patience_early_stopping,
+                mode="min",
+            ),
+            LearningRateMonitor(logging_interval="step"),
+        ]
+
+        trainer_params = {
+            "accelerator": training_config.accelerator,
+            "callbacks": callbacks,
+            "max_epochs": training_config.num_epochs,
+            "num_sanity_val_steps": 2,
+            "strategy": "auto",
+            "log_every_n_steps": 1,
+            "enable_progress_bar": True,
+        }
+
+        if training_config.trainer_params is not None:
+            trainer_params.update(training_config.trainer_params)
+
+        trainer = pl.Trainer(**trainer_params)
+
+        torch.cuda.empty_cache()
+        torch.set_float32_matmul_precision("medium")
+
+        if verbose:
+            logger.info("Launching training...")
+        start = time.time()
+
+        trainer.fit(self.lightning_module, train_dataloader, val_dataloader)
+
+        if verbose:
+            end = time.time()
+            logger.info(f"Training completed in {end - start:.2f} seconds.")
+
+        best_model_path = trainer.checkpoint_callback.best_model_path
+        self.checkpoint_path = best_model_path
+
+        self.lightning_module = TextClassificationModule.load_from_checkpoint(
+            best_model_path,
+            model=self.pytorch_model,
+            loss=training_config.loss,
+            weights_only=False,  # Required: checkpoint contains optimizer/scheduler state
+        )
+
+        self.pytorch_model = self.lightning_module.model.to(self.device)
+
+        self.save_path = training_config.save_path
+        self.save(self.save_path)
+
+        self.lightning_module.eval()
+
+    def _check_XY(
+        self, X: np.ndarray, Y: np.ndarray, raw_categorical_inputs: bool, raw_labels: bool
+    ) -> Tuple[Dict[str, Any], np.ndarray]:
+        X_checked = self._check_X(X, raw_categorical_inputs)
+        Y_checked = self._check_Y(Y, raw_labels)
+
+        if X_checked["text"].shape[0] != len(Y_checked):
+            raise ValueError("X and Y must have the same number of observations.")
+
+        return X_checked, Y_checked
+
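+    # Note: the _check_* helpers below expect X as a numpy array of shape (N, d)
+    # whose first column is the text and whose remaining columns are categorical
+    # variables, e.g. np.array([["red dress", 0, 2], ["blue shirt", 1, 0]], dtype=object).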
+    @staticmethod
+    def _check_text_col(X):
+        assert isinstance(
+            X, np.ndarray
+        ), "X must be a numpy array of shape (N,d), with the first column being the text and the rest being the categorical variables."
+
+        try:
+            if X.ndim > 1:
+                text = X[:, 0].astype(str)
+            else:
+                text = X.astype(str)
+        except ValueError as err:
+            raise ValueError(
+                "The first column of X must be castable to string format."
+            ) from err
+
+        return text
+
+    def _check_categorical_variables(
+        self, X: np.ndarray, raw_categorical_inputs: bool
+    ) -> np.ndarray:
+        """Validate and encode categorical variables from X.
+
+        If a ``value_encoder`` was provided at initialization, raw string
+        or mixed values are encoded to integers via that encoder. Otherwise the
+        categorical columns must already be integer-encodable.
+
+        Args:
+            X: Full input array whose first column is text and whose remaining
+                columns are categorical variables.
+            raw_categorical_inputs: Whether the categorical columns contain raw
+                values that must be encoded via the ``value_encoder``.
+
+        Returns:
+            Integer-encoded categorical array of shape (N, n_cat_features).
+
+        Raises:
+            ValueError: If the number of categorical features does not match the
+                model configuration, if values exceed vocabulary bounds, or if
+                values cannot be cast to integers and no encoder was provided.
+        """
+        assert self.categorical_var_net is not None
+
+        num_cat_vars = X.shape[1] - 1 if X.ndim > 1 else 0
+
+        if num_cat_vars != self.categorical_var_net.num_categorical_features:
+            raise ValueError(
+                f"X must have the same number of categorical variables as the number of "
+                f"embedding layers in the categorical net ({self.categorical_var_net.num_categorical_features})."
+            )
+
+        if raw_categorical_inputs:
+            if self.value_encoder is None:
+                raise ValueError(
+                    "Raw categorical input encoding is enabled, but no value_encoder was provided. Please provide a ValueEncoder to encode raw categorical values to integers."
+                )
+            categorical_variables = self.value_encoder.transform(X[:, 1:]).astype(int)
+        else:
+            categorical_variables = X[:, 1:].astype(int)
+
+        for j in range(num_cat_vars):
+            max_cat_value = categorical_variables[:, j].max()
+            if max_cat_value >= self.categorical_var_net.categorical_vocabulary_sizes[j]:
+                raise ValueError(
+                    f"Categorical variable at index {j} has value {max_cat_value} which exceeds "
+                    f"the vocabulary size of {self.categorical_var_net.categorical_vocabulary_sizes[j]}."
+                )
+
+        return categorical_variables
+
+    def _check_X(self, X: np.ndarray, raw_categorical_inputs: bool) -> Dict[str, Any]:
+        text = self._check_text_col(X)
+
+        categorical_variables = None
+        if self.categorical_var_net is not None:
+            categorical_variables = self._check_categorical_variables(X, raw_categorical_inputs)
+
+        return {"text": text, "categorical_variables": categorical_variables}
+
+    def _check_Y(self, Y, raw_labels: bool) -> Union[np.ndarray, list]:
+        if self.ragged_multilabel:
+            assert isinstance(
+                Y, list
+            ), "Y must be a list of lists for ragged multilabel classification."
+            for row in Y:
+                assert isinstance(row, list), "Each element of Y must be a list of labels."
+
+            return Y
+
+        else:
+            assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)."
+            assert (
+                len(Y.shape) == 1 or len(Y.shape) == 2
+            ), "Y must be a numpy array of shape (N,) or (N, num_labels)."
+
+            if raw_labels:
+                if self.value_encoder is None:
+                    raise ValueError(
+                        "Raw label encoding is enabled, but no value_encoder was provided. Please provide a ValueEncoder to encode raw labels to integers."
+                    )
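+                # The value_encoder maps raw label values to contiguous integer ids;
+                # with a list of per-level encoders this is done column by column
+                # (see ValueEncoder.transform_labels).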
+                Y = self.value_encoder.transform_labels(Y)
+            Y = Y.astype(int)
+
+            if isinstance(self.num_classes, list):
+                num_classes_arr = np.array(self.num_classes)
+
+                if (Y.max(axis=0) >= num_classes_arr).any() or Y.min() < 0:
+                    raise ValueError(
+                        f"Y contains class labels outside the expected per-level ranges "
+                        f"[0, {[nc - 1 for nc in self.num_classes]}]."
+                    )
+            elif Y.max() >= self.num_classes or Y.min() < 0:
+                raise ValueError(
+                    f"Y contains class labels outside the range [0, {self.num_classes - 1}]."
+                )
+
+            return Y
+
+    def predict(
+        self,
+        X_test: np.ndarray,
+        raw_categorical_inputs: bool = True,
+        top_k: int = 1,
+        explain_with_label_attention: bool = False,
+        explain_with_captum: bool = False,
+    ):
+        """Predict the top_k most likely classes for each input.
+
+        Args:
+            X_test (np.ndarray): Input data of shape (N, d), where the first column
+                is text and the remaining columns are categorical variables.
+            raw_categorical_inputs (bool): Whether categorical columns contain raw
+                values to be encoded via the value_encoder (default: True).
+            top_k (int): For each sentence, return the top_k most likely
+                predictions (default: 1).
+            explain_with_label_attention (bool): If enabled, use the labels x tokens
+                attention matrix to explain the prediction (default: False).
+            explain_with_captum (bool): Run gradient integration with Captum for
+                explanation (default: False).
+
+        Returns: A dictionary containing the following fields:
+            - prediction (shape (len(text), top_k)): the top_k most likely classes
+              for each query.
+            - confidence (torch.Tensor, shape (len(text), top_k)): the corresponding
+              confidence scores.
+            - if an explanation flag is set: captum_attributions (torch.Tensor,
+              shape (len(text), top_k, seq_len)) and/or label_attention_attributions,
+              plus offset_mapping and word_ids from the tokenizer.
+        """
+
+        explain = explain_with_label_attention or explain_with_captum
+        if explain:
+            return_offsets_mapping = True  # to be passed to the tokenizer
+            return_word_ids = True
+            if self.pytorch_model.token_embedder is None:
+                raise RuntimeError(
+                    "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs."
+                )
+            else:
+                if explain_with_captum:
+                    if not HAS_CAPTUM:
+                        raise ImportError(
+                            "Captum is not installed and is required for explainability. Run 'pip install torchFastText[explainability]' or 'uv add torchFastText[explainability]'."
+                        )
+                    lig = LayerIntegratedGradients(
+                        self.pytorch_model, self.pytorch_model.token_embedder.embedding_layer
+                    )  # initialize a Captum layer gradient integrator
+                if explain_with_label_attention:
+                    if not self.enable_label_attention:
+                        raise RuntimeError(
+                            "Label attention explainability is enabled, but the model was not configured with label attention. Please enable label attention in the model configuration during initialization and retrain."
+                        )
+        else:
+            return_offsets_mapping = False
+            return_word_ids = False
+
+        X_test = self._check_X(X_test, raw_categorical_inputs)
+        text = X_test["text"]
+        categorical_variables = X_test["categorical_variables"]
+
+        self.pytorch_model.eval().cpu()
+
+        tokenize_output = self.tokenizer.tokenize(
+            text.tolist(),
+            return_offsets_mapping=return_offsets_mapping,
+            return_word_ids=return_word_ids,
+        )
+
+        if not isinstance(tokenize_output, TokenizerOutput):
+            raise TypeError(
+                f"Expected TokenizerOutput, got {type(tokenize_output)} from tokenizer.tokenize method."
+            )
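+
+        # TokenizerOutput bundles the encoded batch (input_ids, attention_mask),
+        # plus offset_mapping / word_ids when explanations are requested.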
+        encoded_text = tokenize_output.input_ids  # (batch_size, seq_len)
+        attention_mask = tokenize_output.attention_mask  # (batch_size, seq_len)
+
+        if categorical_variables is not None:
+            categorical_vars = torch.tensor(
+                categorical_variables, dtype=torch.long
+            )  # (batch_size, num_categorical_features): integer ids for the embedding layers
+        else:
+            categorical_vars = torch.empty((encoded_text.shape[0], 0), dtype=torch.long)
+
+        model_output = self.pytorch_model(
+            encoded_text,
+            attention_mask,
+            categorical_vars,
+            return_label_attention_matrix=explain_with_label_attention,
+        )  # forward pass, contains the prediction scores (len(text), num_classes)
+
+        # Multilevel custom model: returns a list of per-level logit tensors.
+        # Each level may have a different number of classes, so we process them
+        # separately then stack into (N, top_k, n_levels) before decoding.
+        if isinstance(model_output, list):
+            int_preds_per_level: List[np.ndarray] = []
+            conf_per_level: List[torch.Tensor] = []
+            for level_logits in cast(List[torch.Tensor], model_output):
+                scores = level_logits.detach().cpu().softmax(dim=-1)
+                level_topk = torch.topk(scores, k=top_k, dim=-1)
+                int_preds_per_level.append(level_topk.indices.numpy())  # (N, top_k)
+                conf_per_level.append(level_topk.values)  # (N, top_k)
+
+            # (N, top_k, n_levels)
+            int_preds_stacked = np.stack(int_preds_per_level, axis=-1)
+            conf_stacked = torch.stack(conf_per_level, dim=-1)
+
+            if self.value_encoder is not None:
+                predictions = self.value_encoder.inverse_transform_labels(int_preds_stacked)
+            else:
+                predictions = int_preds_stacked
+
+            return {
+                "prediction": predictions,
+                "confidence": torch.round(conf_stacked, decimals=2),
+            }
+
+        pred = (
+            model_output["logits"] if explain_with_label_attention else model_output
+        )  # (batch_size, num_classes)
+
+        label_attention_matrix = (
+            model_output["label_attention_matrix"] if explain_with_label_attention else None
+        )
+
+        label_scores = pred.detach().cpu().softmax(dim=1)  # convert to probabilities
+
+        label_scores_topk = torch.topk(label_scores, k=top_k, dim=1)
+
+        integer_predictions = label_scores_topk.indices  # integer class indices (needed for Captum)
+        if self.value_encoder is not None:
+            predictions = self.value_encoder.inverse_transform_labels(integer_predictions.numpy())
+        else:
+            predictions = integer_predictions
+
+        confidence = torch.round(label_scores_topk.values, decimals=2)  # and their scores
+
+        if explain:
+            if explain_with_captum:
+                # Captum explanations
+                captum_attributions = []
+                for k in range(top_k):
+                    attributions = lig.attribute(
+                        (encoded_text, attention_mask, categorical_vars),
+                        target=integer_predictions[:, k],
+                    )  # (batch_size, seq_len)
+                    attributions = attributions.sum(dim=-1)
+                    captum_attributions.append(attributions.detach().cpu())
+
+                captum_attributions = torch.stack(
+                    captum_attributions, dim=1
+                )  # (batch_size, top_k, seq_len)
+            else:
+                captum_attributions = None
+
+            return {
+                "prediction": predictions,
+                "confidence": confidence,
+                "captum_attributions": captum_attributions,
+                "label_attention_attributions": label_attention_matrix,
+                "offset_mapping": tokenize_output.offset_mapping,
+                "word_ids": tokenize_output.word_ids,
+            }
+        else:
+            return {
+                "prediction": predictions,
+                "confidence": confidence,
+            }
+
+    def save(self, path: Union[str, Path]) -> None:
+        """Save the complete torchTextClassifiers instance to disk. 
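+
+        Custom models built via ``from_model`` are saved as a pickled architecture
+        plus a ``state_dict``; standard models are saved as a PyTorch Lightning
+        checkpoint.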
+ + This saves: + - Model configuration + - Tokenizer state + - PyTorch Lightning checkpoint (if trained) + - All other instance attributes + + Args: + path: Directory path where the model will be saved + + Example: + >>> ttc = torchTextClassifiers(tokenizer, model_config) + >>> ttc.train(X_train, y_train, training_config) + >>> ttc.save("my_model") + """ + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + is_custom_model = getattr(self, "_custom_model", False) + + # Custom models: save architecture as pickle + weights as state_dict. + # Standard models: save a full PyTorch Lightning checkpoint. + checkpoint_path = None + if is_custom_model: + with open(path / "pytorch_model.pkl", "wb") as f: + pickle.dump(self.pytorch_model, f) + torch.save(self.pytorch_model.state_dict(), path / "model_weights.pt") + elif hasattr(self, "lightning_module"): + checkpoint_path = path / "model_checkpoint.ckpt" + trainer = pl.Trainer() + trainer.strategy.connect(self.lightning_module) + trainer.save_checkpoint(checkpoint_path) + + metadata: Dict[str, Any] = { + "is_custom_model": is_custom_model, + "loss": self.lightning_module.loss if hasattr(self, "lightning_module") else None, + "ragged_multilabel": self.ragged_multilabel, + "num_classes": self.num_classes, + "checkpoint_path": str(checkpoint_path) if checkpoint_path else None, + "device": str(self.device) if hasattr(self, "device") else None, + "has_value_encoder": self.value_encoder is not None, + } + + if not is_custom_model: + metadata.update( + { + "model_config": self.model_config.to_dict(), + "vocab_size": self.vocab_size, + "embedding_dim": self.embedding_dim, + "categorical_vocabulary_sizes": self.categorical_vocabulary_sizes, + } + ) + + with open(path / "metadata.pkl", "wb") as f: + pickle.dump(metadata, f) + + tokenizer_path = path / "tokenizer.pkl" + with open(tokenizer_path, "wb") as f: + pickle.dump(self.tokenizer, f) + + if self.value_encoder is not None: + with open(path / "value_encoder.pkl", "wb") as f: + pickle.dump(self.value_encoder, f) + + logger.info(f"Model saved successfully to {path}") + + @classmethod + def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassifiers": + """Load a torchTextClassifiers instance from disk. + + Args: + path: Directory path where the model was saved + device: Device to load the model on ('auto', 'cpu', 'cuda', etc.) 
+
+        Returns:
+            Loaded torchTextClassifiers instance
+
+        Example:
+            >>> loaded_ttc = torchTextClassifiers.load("my_model")
+            >>> predictions = loaded_ttc.predict(X_test)
+        """
+        path = Path(path)
+
+        if not path.exists():
+            raise FileNotFoundError(f"Model directory not found: {path}")
+
+        with open(path / "metadata.pkl", "rb") as f:
+            metadata = pickle.load(f)
+
+        with open(path / "tokenizer.pkl", "rb") as f:
+            tokenizer = pickle.load(f)
+
+        resolved_device = (
+            torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            if device == "auto"
+            else torch.device(device)
+        )
+
+        value_encoder = None
+        if metadata.get("has_value_encoder"):
+            encoder_path = path / "value_encoder.pkl"
+            if encoder_path.exists():
+                with open(encoder_path, "rb") as f:
+                    value_encoder = pickle.load(f)
+
+        if metadata.get("is_custom_model", False):
+            with open(path / "pytorch_model.pkl", "rb") as f:
+                pytorch_model = pickle.load(f)
+            weights_path = path / "model_weights.pt"
+            if weights_path.exists():
+                pytorch_model.load_state_dict(torch.load(weights_path, weights_only=True))
+                logger.info(f"Model weights loaded from {weights_path}")
+            instance = cls.from_model(
+                tokenizer=tokenizer,
+                pytorch_model=pytorch_model,
+                value_encoder=value_encoder,
+                ragged_multilabel=metadata["ragged_multilabel"],
+            )
+            instance.device = resolved_device
+            pytorch_model.to(resolved_device)
+            logger.info(f"Model loaded successfully from {path}")
+            return instance
+
+        model_config = ModelConfig.from_dict(metadata["model_config"])
+
+        instance = cls(
+            tokenizer=tokenizer,
+            model_config=model_config,
+            ragged_multilabel=metadata["ragged_multilabel"],
+            value_encoder=value_encoder,
+        )
+
+        instance.device = resolved_device
+
+        if metadata.get("checkpoint_path"):
+            checkpoint_path = path / "model_checkpoint.ckpt"
+            if checkpoint_path.exists():
+                loss = metadata.get("loss") or torch.nn.CrossEntropyLoss()
+                instance.lightning_module = TextClassificationModule.load_from_checkpoint(
+                    str(checkpoint_path),
+                    model=instance.pytorch_model,
+                    loss=loss,
+                    weights_only=False,
+                )
+                instance.pytorch_model = instance.lightning_module.model.to(resolved_device)
+                instance.checkpoint_path = str(checkpoint_path)
+                logger.info(f"Model checkpoint loaded from {checkpoint_path}")
+            else:
+                logger.warning(f"Checkpoint file not found at {checkpoint_path}")
+
+        logger.info(f"Model loaded successfully from {path}")
+        return instance
+
+    def __repr__(self):
+        model_type = (
+            self.lightning_module.__repr__()
+            if hasattr(self, "lightning_module")
+            else self.pytorch_model.__repr__()
+        )
+
+        tokenizer_info = self.tokenizer.__repr__()
+
+        cat_forward_type = (
+            self.categorical_var_net.forward_type.name
+            if self.categorical_var_net is not None
+            else "None"
+        )
+
+        lines = [
+            "torchTextClassifiers(",
+            f"  tokenizer = {tokenizer_info},",
+            f"  model = {model_type},",
+            f"  categorical_forward_type = {cat_forward_type},",
+            f"  num_classes = {self.num_classes},",  # also set for custom models, which have no model_config
+            f"  embedding_dim = {getattr(self, 'embedding_dim', None)},",  # None for custom models
+            ")",
+        ]
+        return "\n".join(lines)
diff --git a/torchTextClassifiers/value_encoder/value_encoder.py b/torchTextClassifiers/value_encoder/value_encoder.py
index 8215093..7ca946b 100644
--- a/torchTextClassifiers/value_encoder/value_encoder.py
+++ b/torchTextClassifiers/value_encoder/value_encoder.py
@@ -56,14 +56,21 @@ class ValueEncoder:
 
     def __init__(
         self,
-        label_encoder: DictEncoder | LabelEncoder,
+        label_encoder: DictEncoder | LabelEncoder | list[DictEncoder | LabelEncoder],
         categorical_encoders: Optional[dict[str, DictEncoder | 
LabelEncoder]] = None,
     ):
         self.categorical_encoders = categorical_encoders
 
-        if not isinstance(label_encoder, (DictEncoder, LabelEncoder)):
+        if isinstance(label_encoder, list):
+            for enc in label_encoder:
+                if not isinstance(enc, (DictEncoder, LabelEncoder)):
+                    raise TypeError(
+                        "Each element of label_encoder list must be a DictEncoder or LabelEncoder, "
+                        f"got {type(enc)}"
+                    )
+        elif not isinstance(label_encoder, (DictEncoder, LabelEncoder)):
             raise TypeError(
-                "label_encoder must be a DictEncoder or LabelEncoder instance, "
+                "label_encoder must be a DictEncoder, LabelEncoder, or list thereof, "
                 f"got {type(label_encoder)}"
             )
         self.label_encoder = label_encoder
@@ -88,10 +95,21 @@ def vocabulary_sizes(self) -> list[int]:
 
     @property
-    def num_classes(self) -> int:
-        """Number of unique classes in the label encoder, if provided."""
-        if isinstance(self.label_encoder, DictEncoder):
-            return len(self.label_encoder.mapping)
-        elif hasattr(self.label_encoder, "classes_"):
-            return len(self.label_encoder.classes_)
+    def num_classes(self) -> int | list[int]:
+        """Number of unique classes, or a list of counts for per-level encoders."""
+
+        def _get_num_classes(enc):
+            if isinstance(enc, DictEncoder):
+                return len(enc.mapping)
+            elif hasattr(enc, "classes_"):
+                return len(enc.classes_)
+            else:
+                raise TypeError(f"Unsupported encoder type: {type(enc)}")
+
+        if isinstance(
+            self.label_encoder, (DictEncoder, LabelEncoder)
+        ):
+            return _get_num_classes(self.label_encoder)
+        elif isinstance(self.label_encoder, list):
+            return [_get_num_classes(enc) for enc in self.label_encoder]
         else:
             raise TypeError(f"Unsupported label encoder type: {type(self.label_encoder)}")
@@ -137,38 +155,85 @@ def transform_labels(self, y_labels: np.ndarray) -> np.ndarray:
         Values are converted to strings before lookup. Unknown values raise a ValueError.
 
         Args:
-            y_labels: Array of shape (N,) with label values.
+            y_labels: Array of shape (N,) for a single encoder, or (N, n_levels) for a list
+                of encoders.
 
         Returns:
-            Integer-encoded array of shape (N,), dtype int64.
+            Integer-encoded array of shape (N,) or (N, n_levels), dtype int64.
 
         Raises:
             ValueError: If any label value was not seen during fitting.
         """
-        col = y_labels.astype(str)
-        encoded = self.label_encoder.transform(col)
-        try:
-            return encoded.astype(np.int64)
-        except (TypeError, ValueError):
-            unknown = [v for v, e in zip(col.tolist(), encoded.tolist()) if e is None]
-            raise ValueError(
-                f"Unknown values in label encoder: {unknown}. "
-                "These values were not seen during fitting."
-            )
+
+        def _encode_col(enc, col, level_name="label encoder"):
+            encoded = enc.transform(col)
+            try:
+                return encoded.astype(np.int64)
+            except (TypeError, ValueError):
+                unknown = [v for v, e in zip(col.tolist(), encoded.tolist()) if e is None]
+                raise ValueError(
+                    f"Unknown values in {level_name}: {unknown}. "
+                    "These values were not seen during fitting."
+                )
+
+        if isinstance(self.label_encoder, list):
+            if y_labels.ndim == 1:
+                y_labels = y_labels.reshape(-1, 1)
+            result = np.empty(y_labels.shape, dtype=np.int64)
+            for idx, enc in enumerate(self.label_encoder):
+                result[:, idx] = _encode_col(
+                    enc, y_labels[:, idx].astype(str), f"label encoder at level {idx}"
+                )
+            return result
+
+        return _encode_col(self.label_encoder, y_labels.astype(str))
 
     def inverse_transform_labels(self, y_encoded: np.ndarray) -> np.ndarray:
         """Decode integer-encoded labels back to original values.
 
         Args:
-            y_encoded: Array of shape (N,) with integer-encoded labels. 
+ y_encoded: Supported shapes: + - Single encoder: (N,) or (N, top_k) + - List of encoders: (N, n_levels), or (N, top_k, n_levels) Returns: - Array of shape (N,) with original label values. + Decoded array with the same shape as ``y_encoded``. Raises: ValueError: If any encoded label value was not seen during fitting. """ + def _decode_col(enc, col: np.ndarray) -> np.ndarray: + """Decode a flat 1-D array using a single encoder.""" + if isinstance(enc, DictEncoder): + return np.vectorize(enc.inverse_mapping.get)(col) + elif hasattr(enc, "inverse_transform"): + return enc.inverse_transform(col) + else: + raise TypeError(f"Unsupported label encoder type: {type(enc)}") + + if isinstance(self.label_encoder, list): + if y_encoded.ndim == 1: + y_encoded = y_encoded.reshape(-1, 1) + + if y_encoded.ndim == 2: + # (N, n_levels) + result = np.empty(y_encoded.shape, dtype=object) + for idx, enc in enumerate(self.label_encoder): + result[:, idx] = _decode_col(enc, y_encoded[:, idx]) + return result + + if y_encoded.ndim == 3: + # (N, top_k, n_levels) + n, top_k = y_encoded.shape[0], y_encoded.shape[1] + result = np.empty(y_encoded.shape, dtype=object) + for idx, enc in enumerate(self.label_encoder): + flat = y_encoded[:, :, idx].ravel() + result[:, :, idx] = _decode_col(enc, flat).reshape(n, top_k) + return result + + raise ValueError( + f"Expected 1-D, 2-D, or 3-D array for list encoder, got {y_encoded.ndim}-D" + ) + if isinstance(self.label_encoder, DictEncoder): - inverse_mapping = self.label_encoder.inverse_mapping - return np.vectorize(inverse_mapping.get)(y_encoded) + return np.vectorize(self.label_encoder.inverse_mapping.get)(y_encoded) elif hasattr(self.label_encoder, "inverse_transform"): shape = y_encoded.shape result = self.label_encoder.inverse_transform(y_encoded.ravel())