From b37aa657ca60dc34706a6b0cafcf9ca7a42ca073 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Fri, 24 Apr 2026 11:42:51 +0000 Subject: [PATCH 1/6] chore: ruff format --- torchTextClassifiers/model/components/__init__.py | 14 ++++++++++++-- .../model/components/classification_head.py | 2 +- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/torchTextClassifiers/model/components/__init__.py b/torchTextClassifiers/model/components/__init__.py index 3db1a73..07944f0 100644 --- a/torchTextClassifiers/model/components/__init__.py +++ b/torchTextClassifiers/model/components/__init__.py @@ -9,5 +9,15 @@ ) from .classification_head import ClassificationHead as ClassificationHead from .text_embedder import LabelAttentionConfig as LabelAttentionConfig -from .text_embedder import TokenEmbedder as TokenEmbedder, TokenEmbedderConfig as TokenEmbedderConfig -from .text_embedder import SentenceEmbedder as SentenceEmbedder, SentenceEmbedderConfig as SentenceEmbedderConfig \ No newline at end of file +from .text_embedder import ( + SentenceEmbedder as SentenceEmbedder, +) +from .text_embedder import ( + SentenceEmbedderConfig as SentenceEmbedderConfig, +) +from .text_embedder import ( + TokenEmbedder as TokenEmbedder, +) +from .text_embedder import ( + TokenEmbedderConfig as TokenEmbedderConfig, +) diff --git a/torchTextClassifiers/model/components/classification_head.py b/torchTextClassifiers/model/components/classification_head.py index bd21f52..475957c 100644 --- a/torchTextClassifiers/model/components/classification_head.py +++ b/torchTextClassifiers/model/components/classification_head.py @@ -25,7 +25,7 @@ def __init__( super().__init__() if net is not None: self.net = net - + # --- Custom net should either be a Sequential or a Linear --- if not (isinstance(net, nn.Sequential) or isinstance(net, nn.Linear)): raise ValueError("net must be an nn.Sequential when provided.") From efb626e49132ffce231e00dcaf46bda65e75f11e Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Fri, 24 Apr 2026 11:45:32 +0000 Subject: [PATCH 2/6] chore: add multilevel example + ruff format --- examples/advanced_training.py | 187 ++++++------- examples/basic_classification.py | 164 ++++++----- examples/multiclass_classification.py | 155 ++++++----- examples/multilevel_example.py | 317 ++++++++++++++++++++++ examples/simple_explainability_example.py | 224 ++++++++------- examples/using_additional_features.py | 115 ++++---- 6 files changed, 747 insertions(+), 415 deletions(-) create mode 100644 examples/multilevel_example.py diff --git a/examples/advanced_training.py b/examples/advanced_training.py index 03c48a7..82b5364 100644 --- a/examples/advanced_training.py +++ b/examples/advanced_training.py @@ -7,7 +7,6 @@ """ import os -import random import warnings import numpy as np @@ -17,13 +16,14 @@ from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers from torchTextClassifiers.tokenizers import WordPieceTokenizer + def main(): # Set seed for reproducibility SEED = 42 # Set environment variables for full reproducibility - os.environ['PYTHONHASHSEED'] = str(SEED) - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + os.environ["PYTHONHASHSEED"] = str(SEED) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Use PyTorch Lightning's seed_everything for comprehensive seeding seed_everything(SEED, workers=True) @@ -35,10 +35,7 @@ def main(): # Suppress PyTorch Lightning warnings for cleaner output warnings.filterwarnings( - 'ignore', - message='.*', - category=UserWarning, - 
module='pytorch_lightning' + "ignore", message=".*", category=UserWarning, module="pytorch_lightning" ) print("āš™ļø Advanced Training Configuration Example") @@ -46,7 +43,7 @@ def main(): # Create a larger dataset for demonstrating advanced training print("šŸ“ Creating training dataset...") - + # Generate more diverse training data positive_samples = [ "Excellent product with outstanding quality and performance.", @@ -58,9 +55,9 @@ def main(): "Fantastic features and user-friendly interface design provided.", "Outstanding durability and reliability in daily usage.", "Impressive performance and excellent build quality throughout.", - "Wonderful experience from purchase to delivery service." + "Wonderful experience from purchase to delivery service.", ] - + negative_samples = [ "Terrible product quality, completely disappointed with purchase.", "Poor customer service and slow delivery times experienced.", @@ -71,30 +68,34 @@ def main(): "Horrible user experience and confusing interface design.", "Broke after few days of normal usage.", "Poor value for money, better alternatives available.", - "Disappointing performance and unreliable functionality shown." + "Disappointing performance and unreliable functionality shown.", ] - + # Combine and create arrays X_train = np.array(positive_samples + negative_samples) y_train = np.array([1] * len(positive_samples) + [0] * len(negative_samples)) - + # Validation data - X_val = np.array([ - "Good product with decent quality for the price.", - "Not satisfied with the purchase, poor quality.", - "Excellent service and great product quality.", - "Disappointed with the product performance results." - ]) + X_val = np.array( + [ + "Good product with decent quality for the price.", + "Not satisfied with the purchase, poor quality.", + "Excellent service and great product quality.", + "Disappointed with the product performance results.", + ] + ) y_val = np.array([1, 0, 1, 0]) - + # Test data - X_test = np.array([ - "Outstanding product with amazing features!", - "Terrible quality, complete waste of money.", - "Great value and excellent customer support." 
- ]) + X_test = np.array( + [ + "Outstanding product with amazing features!", + "Terrible quality, complete waste of money.", + "Great value and excellent customer support.", + ] + ) y_test = np.array([1, 0, 1]) - + print(f"Training samples: {len(X_train)}") print(f"Validation samples: {len(X_val)}") print(f"Test samples: {len(X_test)}") @@ -109,15 +110,9 @@ def main(): # Example 1: Basic training with default settings print("\nšŸŽÆ Example 1: Basic training with default settings...") - model_config = ModelConfig( - embedding_dim=100, - num_classes=2 - ) + model_config = ModelConfig(embedding_dim=100, num_classes=2) - classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=model_config - ) + classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config) print("āœ… Classifier created successfully!") training_config = TrainingConfig( @@ -126,45 +121,38 @@ def main(): lr=1e-3, patience_early_stopping=5, num_workers=0, - trainer_params={'deterministic': True} + trainer_params={"deterministic": True}, ) classifier.train( - X_train, y_train, - training_config=training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_train, y_train, training_config=training_config, X_val=X_val, y_val=y_val, verbose=True ) result = classifier.predict(X_test) basic_predictions = result["prediction"].squeeze().numpy() basic_accuracy = (basic_predictions == y_test).mean() print(f"āœ… Basic training completed! Accuracy: {basic_accuracy:.3f}") - + # Example 2: Advanced training with custom Lightning trainer parameters print("\nšŸš€ Example 2: Advanced training with custom parameters...") # Create a new classifier for comparison - advanced_model_config = ModelConfig( - embedding_dim=100, - num_classes=2 - ) + advanced_model_config = ModelConfig(embedding_dim=100, num_classes=2) advanced_classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=advanced_model_config + tokenizer=tokenizer, model_config=advanced_model_config ) print("āœ… Advanced classifier created successfully!") # Custom trainer parameters for advanced features advanced_trainer_params = { - 'accelerator': 'auto', # Use GPU if available, else CPU - 'precision': 32, # Use 32-bit precision - 'gradient_clip_val': 1.0, # Gradient clipping - 'accumulate_grad_batches': 2, # Gradient accumulation - 'deterministic': True, # For reproducible results - 'enable_progress_bar': True, # Show progress bar - 'log_every_n_steps': 5, # Log every 5 steps + "accelerator": "auto", # Use GPU if available, else CPU + "precision": 32, # Use 32-bit precision + "gradient_clip_val": 1.0, # Gradient clipping + "accumulate_grad_batches": 2, # Gradient accumulation + "deterministic": True, # For reproducible results + "enable_progress_bar": True, # Show progress bar + "log_every_n_steps": 5, # Log every 5 steps } advanced_training_config = TrainingConfig( @@ -173,33 +161,32 @@ def main(): lr=1e-3, patience_early_stopping=7, num_workers=0, - trainer_params=advanced_trainer_params + trainer_params=advanced_trainer_params, ) advanced_classifier.train( - X_train, y_train, + X_train, + y_train, training_config=advanced_training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_val=X_val, + y_val=y_val, + verbose=True, ) advanced_result = advanced_classifier.predict(X_test) advanced_predictions = advanced_result["prediction"].squeeze().numpy() advanced_accuracy = (advanced_predictions == y_test).mean() print(f"āœ… Advanced training completed! 
Accuracy: {advanced_accuracy:.3f}") - + # Example 3: Training with CPU-only (useful for small datasets or debugging) print("\nšŸ’» Example 3: CPU-only training...") cpu_model_config = ModelConfig( embedding_dim=64, # Smaller embedding for faster CPU training - num_classes=2 + num_classes=2, ) - cpu_classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=cpu_model_config - ) + cpu_classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=cpu_model_config) print("āœ… CPU classifier created successfully!") cpu_training_config = TrainingConfig( @@ -208,44 +195,40 @@ def main(): lr=1e-3, patience_early_stopping=3, num_workers=0, # No multiprocessing for CPU - trainer_params={'deterministic': True, 'accelerator': 'cpu'} + trainer_params={"deterministic": True, "accelerator": "cpu"}, ) cpu_classifier.train( - X_train, y_train, + X_train, + y_train, training_config=cpu_training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_val=X_val, + y_val=y_val, + verbose=True, ) cpu_result = cpu_classifier.predict(X_test) cpu_predictions = cpu_result["prediction"].squeeze().numpy() cpu_accuracy = (cpu_predictions == y_test).mean() print(f"āœ… CPU training completed! Accuracy: {cpu_accuracy:.3f}") - + # Example 4: Custom training with specific Lightning callbacks print("\nšŸ”§ Example 4: Training with custom callbacks...") - custom_model_config = ModelConfig( - embedding_dim=128, - num_classes=2 - ) + custom_model_config = ModelConfig(embedding_dim=128, num_classes=2) - custom_classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=custom_model_config - ) + custom_classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=custom_model_config) print("āœ… Custom classifier created successfully!") # Custom trainer with specific monitoring and checkpointing custom_trainer_params = { - 'max_epochs': 25, - 'enable_progress_bar': True, - 'log_every_n_steps': 1, - 'check_val_every_n_epoch': 2, # Validate every 2 epochs - 'enable_checkpointing': True, - 'enable_model_summary': True, - 'deterministic': True, + "max_epochs": 25, + "enable_progress_bar": True, + "log_every_n_steps": 1, + "check_val_every_n_epoch": 2, # Validate every 2 epochs + "enable_checkpointing": True, + "enable_model_summary": True, + "deterministic": True, } custom_training_config = TrainingConfig( @@ -254,21 +237,23 @@ def main(): lr=1e-3, patience_early_stopping=8, num_workers=0, - trainer_params=custom_trainer_params + trainer_params=custom_trainer_params, ) custom_classifier.train( - X_train, y_train, + X_train, + y_train, training_config=custom_training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_val=X_val, + y_val=y_val, + verbose=True, ) custom_result = custom_classifier.predict(X_test) custom_predictions = custom_result["prediction"].squeeze().numpy() custom_accuracy = (custom_predictions == y_test).mean() print(f"āœ… Custom training completed! 
Accuracy: {custom_accuracy:.3f}") - + # Compare all training approaches print("\nšŸ“Š Training Comparison Results:") print("-" * 50) @@ -276,35 +261,35 @@ def main(): print(f"Advanced training: {advanced_accuracy:.3f}") print(f"CPU-only training: {cpu_accuracy:.3f}") print(f"Custom training: {custom_accuracy:.3f}") - + # Find best performing model results = { - 'Basic': basic_accuracy, - 'Advanced': advanced_accuracy, - 'CPU-only': cpu_accuracy, - 'Custom': custom_accuracy + "Basic": basic_accuracy, + "Advanced": advanced_accuracy, + "CPU-only": cpu_accuracy, + "Custom": custom_accuracy, } best_method = max(results, key=results.get) print(f"\nšŸ† Best performing method: {best_method} (Accuracy: {results[best_method]:.3f})") - + # Demonstrate prediction with best model print(f"\nšŸ”® Making predictions with {best_method.lower()} model...") best_classifier = { - 'Basic': classifier, - 'Advanced': advanced_classifier, - 'CPU-only': cpu_classifier, - 'Custom': custom_classifier + "Basic": classifier, + "Advanced": advanced_classifier, + "CPU-only": cpu_classifier, + "Custom": custom_classifier, }[best_method] - + predictions = best_classifier.predict(X_test) print("Test predictions:") for i, (text, pred, true) in enumerate(zip(X_test, predictions, y_test)): sentiment = "Positive" if pred == 1 else "Negative" correct = "āœ…" if pred == true else "āŒ" - print(f"{i+1}. {correct} {sentiment}: {text[:50]}...") + print(f"{i + 1}. {correct} {sentiment}: {text[:50]}...") print("\nšŸŽ‰ Advanced training example completed successfully!") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/basic_classification.py b/examples/basic_classification.py index 695115d..9da2362 100644 --- a/examples/basic_classification.py +++ b/examples/basic_classification.py @@ -5,12 +5,10 @@ text classification using the Wrapper. """ -import os -import random import warnings import numpy as np -import torch + from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers from torchTextClassifiers.tokenizers import WordPieceTokenizer @@ -18,10 +16,7 @@ def main(): # Suppress PyTorch Lightning batch_size inference warnings for cleaner output warnings.filterwarnings( - 'ignore', - message='.*', - category=UserWarning, - module='pytorch_lightning' + "ignore", message=".*", category=UserWarning, module="pytorch_lightning" ) print("šŸš€ Basic Text Classification Example") @@ -29,77 +24,85 @@ def main(): # Create sample data print("šŸ“ Creating sample data...") - X_train = np.array([ - "I love this product! It's amazing and works perfectly.", - "This is terrible. Worst purchase ever made.", - "Great quality and fast shipping. Highly recommend!", - "Poor quality, broke after one day. Very disappointed.", - "Excellent customer service and great value for money.", - "Overpriced and doesn't work as advertised.", - "Perfect! Exactly what I was looking for.", - "Waste of money. Should have read reviews first.", - "Outstanding product with excellent build quality.", - "Cheap plastic, feels like it will break soon.", - "Absolutely fantastic! Exceeded all my expectations.", - "Horrible experience. Customer service was rude and unhelpful.", - "Best purchase I've made this year. Five stars!", - "Defective item arrived. Packaging was also damaged.", - "Super impressed with the performance and durability.", - "Total disappointment. Doesn't match the description at all.", - "Wonderful product! My whole family loves it.", - "Avoid at all costs. 
Complete waste of time and money.", - "Remarkable quality for the price. Very satisfied!", - "Broke within a week. Clearly poor manufacturing.", - "Exceptional value! Would definitely buy again.", - "Misleading photos. Product looks nothing like advertised.", - "Works like a charm. Installation was easy too.", - "Returned it immediately. Not worth even half the price.", - "Beautiful design and sturdy construction. Love it!", - "Arrived late and damaged. Very frustrating experience.", - "Top-notch quality! Highly recommend to everyone.", - "Uncomfortable and poorly made. Regret buying this.", - "Perfect fit and great finish. Couldn't be happier!", - "Stopped working after two uses. Complete junk." - ]) - - y_train = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]) # 1=positive, 0=negative - + X_train = np.array( + [ + "I love this product! It's amazing and works perfectly.", + "This is terrible. Worst purchase ever made.", + "Great quality and fast shipping. Highly recommend!", + "Poor quality, broke after one day. Very disappointed.", + "Excellent customer service and great value for money.", + "Overpriced and doesn't work as advertised.", + "Perfect! Exactly what I was looking for.", + "Waste of money. Should have read reviews first.", + "Outstanding product with excellent build quality.", + "Cheap plastic, feels like it will break soon.", + "Absolutely fantastic! Exceeded all my expectations.", + "Horrible experience. Customer service was rude and unhelpful.", + "Best purchase I've made this year. Five stars!", + "Defective item arrived. Packaging was also damaged.", + "Super impressed with the performance and durability.", + "Total disappointment. Doesn't match the description at all.", + "Wonderful product! My whole family loves it.", + "Avoid at all costs. Complete waste of time and money.", + "Remarkable quality for the price. Very satisfied!", + "Broke within a week. Clearly poor manufacturing.", + "Exceptional value! Would definitely buy again.", + "Misleading photos. Product looks nothing like advertised.", + "Works like a charm. Installation was easy too.", + "Returned it immediately. Not worth even half the price.", + "Beautiful design and sturdy construction. Love it!", + "Arrived late and damaged. Very frustrating experience.", + "Top-notch quality! Highly recommend to everyone.", + "Uncomfortable and poorly made. Regret buying this.", + "Perfect fit and great finish. Couldn't be happier!", + "Stopped working after two uses. Complete junk.", + ] + ) + + y_train = np.array( + [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0] + ) # 1=positive, 0=negative + # Validation data - X_val = np.array([ - "Good product, satisfied with purchase.", - "Not worth the money, poor quality.", - "Really happy with this purchase. Great item!", - "Disappointed with the quality. Expected better.", - "Solid product that does what it promises.", - "Don't waste your money on this. Very poor.", - "Impressive quality and quick delivery.", - "Malfunctioned right out of the box. Terrible." - ]) + X_val = np.array( + [ + "Good product, satisfied with purchase.", + "Not worth the money, poor quality.", + "Really happy with this purchase. Great item!", + "Disappointed with the quality. Expected better.", + "Solid product that does what it promises.", + "Don't waste your money on this. Very poor.", + "Impressive quality and quick delivery.", + "Malfunctioned right out of the box. 
Terrible.", + ] + ) y_val = np.array([1, 0, 1, 0, 1, 0, 1, 0]) - + # Test data - X_test = np.array([ - "This is an amazing product with great features!", - "Completely disappointed with this purchase.", - "Excellent build quality and works as expected.", - "Not recommended. Had issues from day one.", - "Fantastic product! Worth every penny.", - "Failed to meet basic expectations. Very poor.", - "Love it! Exactly as described and high quality.", - "Cheap materials and sloppy construction. Avoid.", - "Superb performance and easy to use. Highly satisfied!", - "Unreliable and frustrating. Should have bought elsewhere." - ]) + X_test = np.array( + [ + "This is an amazing product with great features!", + "Completely disappointed with this purchase.", + "Excellent build quality and works as expected.", + "Not recommended. Had issues from day one.", + "Fantastic product! Worth every penny.", + "Failed to meet basic expectations. Very poor.", + "Love it! Exactly as described and high quality.", + "Cheap materials and sloppy construction. Avoid.", + "Superb performance and easy to use. Highly satisfied!", + "Unreliable and frustrating. Should have bought elsewhere.", + ] + ) y_test = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0]) - + print(f"Training samples: {len(X_train)}") print(f"Validation samples: {len(X_val)}") print(f"Test samples: {len(X_test)}") - + # Create and train tokenizer print("\nšŸ—ļø Creating and training WordPiece tokenizer...") tokenizer = WordPieceTokenizer(vocab_size=5000, output_dim=128) - + # Train tokenizer on the training corpus training_corpus = X_train.tolist() tokenizer.train(training_corpus) @@ -107,17 +110,11 @@ def main(): # Create model configuration print("\nšŸ”§ Creating model configuration...") - model_config = ModelConfig( - embedding_dim=50, - num_classes=2 - ) + model_config = ModelConfig(embedding_dim=50, num_classes=2) # Create classifier print("\nšŸ”Ø Creating classifier...") - classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=model_config - ) + classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config) print("āœ… Classifier created successfully!") print(classifier) # Train the model @@ -130,13 +127,10 @@ def main(): num_workers=0, # Use 0 for simple examples to avoid multiprocessing issues ) classifier.train( - X_train, y_train, - training_config=training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_train, y_train, training_config=training_config, X_val=X_val, y_val=y_val, verbose=True ) print("āœ… Training completed!") - + # Make predictions print("\nšŸ”® Making predictions...") result = classifier.predict(X_test) @@ -149,19 +143,19 @@ def main(): # Calculate accuracy accuracy = (predictions == y_test).mean() print(f"Test accuracy: {accuracy:.3f}") - + # Show detailed results print("\nšŸ“Š Detailed Results:") print("-" * 40) for i, (text, pred, true) in enumerate(zip(X_test, predictions, y_test)): sentiment = "Positive" if pred == 1 else "Negative" correct = "āœ…" if pred == true else "āŒ" - print(f"{i+1}. {correct} Predicted: {sentiment}") + print(f"{i + 1}. 
{correct} Predicted: {sentiment}") print(f" Text: {text[:50]}...") print() - + print("\nšŸŽ‰ Example completed successfully!") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/multiclass_classification.py b/examples/multiclass_classification.py index 92e5d48..e50063f 100644 --- a/examples/multiclass_classification.py +++ b/examples/multiclass_classification.py @@ -7,7 +7,6 @@ """ import os -import random import warnings import numpy as np @@ -17,13 +16,14 @@ from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers from torchTextClassifiers.tokenizers import WordPieceTokenizer + def main(): # Set seed for reproducibility SEED = 42 # Set environment variables for full reproducibility - os.environ['PYTHONHASHSEED'] = str(SEED) - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + os.environ["PYTHONHASHSEED"] = str(SEED) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Use PyTorch Lightning's seed_everything for comprehensive seeding seed_everything(SEED, workers=True) @@ -35,10 +35,7 @@ def main(): # Suppress PyTorch Lightning warnings for cleaner output warnings.filterwarnings( - 'ignore', - message='.*', - category=UserWarning, - module='pytorch_lightning' + "ignore", message=".*", category=UserWarning, module="pytorch_lightning" ) print("šŸŽ­ Multi-class Text Classification Example") @@ -46,54 +43,76 @@ def main(): # Create multi-class sample data (3 classes: 0=negative, 1=neutral, 2=positive) print("šŸ“ Creating multi-class sentiment data...") - X_train = np.array([ - # Negative examples (class 0) - "This product is terrible and I hate it completely.", - "Worst purchase ever. Total waste of money.", - "Absolutely awful quality. Very disappointed.", - "Poor service and terrible product quality.", - "I regret buying this. Complete failure.", - - # Neutral examples (class 1) - "The product is okay, nothing special though.", - "It works but could be better designed.", - "Average quality for the price point.", - "Not bad but not great either.", - "It's fine, meets basic expectations.", - - # Positive examples (class 2) - "Excellent product! Highly recommended!", - "Amazing quality and great customer service.", - "Perfect! Exactly what I was looking for.", - "Outstanding value and excellent performance.", - "Love it! Will definitely buy again." - ]) - - y_train = np.array([0, 0, 0, 0, 0, # negative - 1, 1, 1, 1, 1, # neutral - 2, 2, 2, 2, 2]) # positive - + X_train = np.array( + [ + # Negative examples (class 0) + "This product is terrible and I hate it completely.", + "Worst purchase ever. Total waste of money.", + "Absolutely awful quality. Very disappointed.", + "Poor service and terrible product quality.", + "I regret buying this. Complete failure.", + # Neutral examples (class 1) + "The product is okay, nothing special though.", + "It works but could be better designed.", + "Average quality for the price point.", + "Not bad but not great either.", + "It's fine, meets basic expectations.", + # Positive examples (class 2) + "Excellent product! Highly recommended!", + "Amazing quality and great customer service.", + "Perfect! Exactly what I was looking for.", + "Outstanding value and excellent performance.", + "Love it! 
Will definitely buy again.", + ] + ) + + y_train = np.array( + [ + 0, + 0, + 0, + 0, + 0, # negative + 1, + 1, + 1, + 1, + 1, # neutral + 2, + 2, + 2, + 2, + 2, + ] + ) # positive + # Validation data - X_val = np.array([ - "Bad quality, not recommended.", # negative - "It's okay, does the job.", # neutral - "Great product, very satisfied!" # positive - ]) + X_val = np.array( + [ + "Bad quality, not recommended.", # negative + "It's okay, does the job.", # neutral + "Great product, very satisfied!", # positive + ] + ) y_val = np.array([0, 1, 2]) - + # Test data - X_test = np.array([ - "This is absolutely horrible!", - "It's an average product, nothing more.", - "Fantastic! Love every aspect of it!", - "Really poor design and quality.", - "Works well, good value for money.", - "Outstanding product with amazing features!" - ]) + X_test = np.array( + [ + "This is absolutely horrible!", + "It's an average product, nothing more.", + "Fantastic! Love every aspect of it!", + "Really poor design and quality.", + "Works well, good value for money.", + "Outstanding product with amazing features!", + ] + ) y_test = np.array([0, 1, 2, 0, 1, 2]) - + print(f"Training samples: {len(X_train)}") - print(f"Class distribution: Negative={sum(y_train==0)}, Neutral={sum(y_train==1)}, Positive={sum(y_train==2)}") + print( + f"Class distribution: Negative={sum(y_train == 0)}, Neutral={sum(y_train == 1)}, Positive={sum(y_train == 2)}" + ) # Create and train tokenizer print("\nšŸ—ļø Creating and training WordPiece tokenizer...") @@ -106,15 +125,12 @@ def main(): print("\nšŸ”§ Creating model configuration...") model_config = ModelConfig( embedding_dim=64, - num_classes=3 # 3 classes for sentiment (negative, neutral, positive) + num_classes=3, # 3 classes for sentiment (negative, neutral, positive) ) # Create classifier print("\nšŸ”Ø Creating multi-class classifier...") - classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=model_config - ) + classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config) print("āœ… Classifier created successfully!") # Train the model @@ -125,16 +141,13 @@ def main(): lr=1e-3, patience_early_stopping=7, num_workers=0, - trainer_params={'deterministic': True} + trainer_params={"deterministic": True}, ) classifier.train( - X_train, y_train, - training_config=training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_train, y_train, training_config=training_config, X_val=X_val, y_val=y_val, verbose=True ) print("āœ… Training completed!") - + # Make predictions print("\nšŸ”® Making predictions...") result = classifier.predict(X_test) @@ -145,10 +158,10 @@ def main(): # Calculate accuracy accuracy = (predictions == y_test).mean() print(f"Test accuracy: {accuracy:.3f}") - + # Define class names for better output class_names = ["Negative", "Neutral", "Positive"] - + # Show detailed results print("\nšŸ“Š Detailed Results:") print("-" * 60) @@ -160,15 +173,17 @@ def main(): if correct: correct_predictions += 1 status = "āœ…" if correct else "āŒ" - - print(f"{i+1}. {status} Predicted: {predicted_sentiment}, True: {true_sentiment}") + + print(f"{i + 1}. 
{status} Predicted: {predicted_sentiment}, True: {true_sentiment}") print(f" Text: {text}") print() - - print(f"Final Accuracy: {correct_predictions}/{len(X_test)} = {correct_predictions/len(X_test):.3f}") - - + + print( + f"Final Accuracy: {correct_predictions}/{len(X_test)} = {correct_predictions / len(X_test):.3f}" + ) + print("\nšŸŽ‰ Multi-class example completed successfully!") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/multilevel_example.py b/examples/multilevel_example.py new file mode 100644 index 0000000..e04b943 --- /dev/null +++ b/examples/multilevel_example.py @@ -0,0 +1,317 @@ +from typing import Optional + +import numpy as np +import pandas as pd +import torch +from sklearn.preprocessing import LabelEncoder +from torch import nn + +import torchTextClassifiers +from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers +from torchTextClassifiers.dataset import TextClassificationDataset +from torchTextClassifiers.model import TextClassificationModule +from torchTextClassifiers.model.components import ( + AttentionConfig, + CategoricalForwardType, + CategoricalVariableNet, + ClassificationHead, + LabelAttentionConfig, + SentenceEmbedder, + SentenceEmbedderConfig, + TokenEmbedder, + TokenEmbedderConfig, +) +from torchTextClassifiers.tokenizers import WordPieceTokenizer +from torchTextClassifiers.value_encoder import DictEncoder, ValueEncoder + +sample_text_data = [ + "This is a positive example", + "This is a negative example", + "Another positive case", + "Another negative case", + "Good example here", + "Bad example here", +] + +categorical_data = np.array( + [ + ["cat", "red"], + ["dog", "blue"], + ["cat", "red"], + ["dog", "blue"], + ["cat", "red"], + ["dog", "blue"], + ] +) + +labels_level_1 = np.array(["positive", "negative", "positive", "negative", "positive", "neutral"]) +labels_level_2 = np.array(["good", "bad", "good", "bad", "good", "bad"]) +labels_level3 = np.array(["A", "B", "D", "B", "C", "B"]) + + +df = pd.DataFrame( + { + "text": sample_text_data, + "category": categorical_data[:, 0], + "color": categorical_data[:, 1], + "label_level_1": labels_level_1, # You can switch to labels_level_2 or labels_level_3 for testing + "label_level_2": labels_level_2, + "label_level_3": labels_level3, + } +) +vocab_size = 10 +tokenizer = WordPieceTokenizer(vocab_size, output_dim=50) +tokenizer.train(sample_text_data) + +encoders = {} +# category : DictEncoder (ours) +feature = "category" +mapping = {val: idx for idx, val in enumerate(df[feature].unique())} +encoders[feature] = DictEncoder(mapping) + +# color: LabelEncoder (sklearn) +le = LabelEncoder() +le.fit(df["color"]) +encoders["color"] = le + +feature = "label_level_1" +le_label = LabelEncoder() +le_label.fit(df[feature]) + +feature = "label_level_2" +le_label_2 = LabelEncoder() +le_label_2.fit(df[feature]) + +feature = "label_level_3" +le_label_3 = DictEncoder({val: idx for idx, val in enumerate(df[feature].unique())}) + +label_encoder = [le_label, le_label_2, le_label_3] +# OR you can also use DictEncoder +# dict_mapping = {val: idx for idx, val in enumerate(df[feature].unique())} +# label_encoder = DictEncoder(dict_mapping) + +value_encoder = ValueEncoder(label_encoder, encoders) + + +model_config = ModelConfig( + embedding_dim=10, + categorical_embedding_dims=5, + n_heads_label_attention=2, + num_classes=value_encoder.num_classes, + attention_config=AttentionConfig(n_layers=2, n_head=5, n_kv_head=5, positional_encoding=False), + 
aggregation_method=None, +) +training_config = TrainingConfig( + num_epochs=1, + batch_size=6, + lr=1e-3, + raw_categorical_inputs=True, +) + +train_dataset = TextClassificationDataset( + texts=df["text"].values, + categorical_variables=value_encoder.transform( + df[["category", "color"]].values + ), # None if no cat vars + tokenizer=tokenizer, + labels=value_encoder.transform_labels( + df[["label_level_1", "label_level_2", "label_level_3"]].values + ), # None if no labels +) +train_dataloader = train_dataset.create_dataloader( + batch_size=training_config.batch_size, + num_workers=training_config.num_workers, + shuffle=False, + **training_config.dataloader_params if training_config.dataloader_params else {}, +) +batch = next(iter(train_dataloader)) + + +token_embedder_config = TokenEmbedderConfig( + vocab_size=tokenizer.vocab_size, + embedding_dim=model_config.embedding_dim, + padding_idx=tokenizer.padding_idx, + attention_config=model_config.attention_config, +) +token_embedder = TokenEmbedder( + token_embedder_config=token_embedder_config, +) +categorical_var_net = CategoricalVariableNet( + categorical_vocabulary_sizes=value_encoder.vocabulary_sizes, + categorical_embedding_dims=model_config.categorical_embedding_dims, + text_embedding_dim=model_config.embedding_dim, +) + +all_sentence_embedders = [] +all_classification_heads = [] + +for num_classes in value_encoder.num_classes: # ty:ignore[not-iterable] + sentence_embedder_config = SentenceEmbedderConfig( + label_attention_config=LabelAttentionConfig( + n_head=model_config.n_heads_label_attention, + num_classes=num_classes, + embedding_dim=model_config.embedding_dim, + ), + aggregation_method=model_config.aggregation_method, + ) + + sentence_embedder = SentenceEmbedder(sentence_embedder_config=sentence_embedder_config) + all_sentence_embedders.append(sentence_embedder) + + classif_head_input_dim = model_config.embedding_dim + if categorical_var_net.forward_type != CategoricalForwardType.SUM_TO_TEXT: + classif_head_input_dim += categorical_var_net.output_dim + + # because we use LabelAttention, the sentence embedder outputs a (num_classes, embedding_dim) tensor, and the classification head should output a single logit per class (i.e. 
num_classes=1) + classification_head = ClassificationHead(input_dim=classif_head_input_dim, num_classes=1) + all_classification_heads.append(classification_head) + + +class MultiLevelTextClassificationModel(nn.Module): + def __init__( + self, + token_embedder: TokenEmbedder, + sentence_embedders: list[SentenceEmbedder], + classification_heads: list[ClassificationHead], + categorical_variable_net: CategoricalVariableNet, + ): + super().__init__() + self.token_embedder = token_embedder + self.sentence_embedders = sentence_embedders + self.classification_heads = classification_heads + self.categorical_variable_net = categorical_variable_net + self.num_classes: list[int] = [ + self.sentence_embedders[i].label_attention_config.num_classes + if self.sentence_embedders[i].label_attention_config is not None + else self.classification_heads[i].num_classes + for i in range(len(self.sentence_embedders)) + ] + + def forward(self, input_ids, attention_mask, categorical_vars=None, **kwargs): + token_embed_output = self.token_embedder(input_ids, attention_mask) + x_token = token_embed_output["token_embeddings"] + x_cat = self.categorical_variable_net(categorical_vars) # (bs, cat_emb_dim) + + print(f"Token embeddings shape: {x_token.shape}") + print( + f"x_cat shape: {x_cat.shape}" + ) # Debugging line to check shape of categorical variable embeddings + outputs = [] + for sentence_embedder, classification_head in zip( + self.sentence_embedders, self.classification_heads + ): + if sentence_embedder.label_attention_config is not None: + num_classes = sentence_embedder.label_attention_config.num_classes + x_cat_level = x_cat.unsqueeze(1).expand(-1, num_classes, -1) + else: + x_cat_level = x_cat + + print( + f"x_cat_level shape: {x_cat_level.shape}" + ) # Debugging line to check shape of categorical variable embeddings after expansion + sentence_embedding = sentence_embedder( + token_embeddings=x_token, attention_mask=attention_mask + )[ + "sentence_embedding" + ] # (bs, embedding_dim) or (bs, num_classes, embedding_dim) if label attention + + print( + f"Sentence embedding shape: {sentence_embedding.shape}" + ) # Debugging line to check shape of sentence embeddings + if ( + self.categorical_variable_net.forward_type + == CategoricalForwardType.AVERAGE_AND_CONCAT + or self.categorical_variable_net.forward_type + == CategoricalForwardType.CONCATENATE_ALL + ): + x_combined = torch.cat((sentence_embedding, x_cat_level), dim=-1) + else: + assert ( + self.categorical_variable_net.forward_type == CategoricalForwardType.SUM_TO_TEXT + ) + + x_combined = sentence_embedding + x_cat_level + + print( + f"x_combined shape: {x_combined.shape}" + ) # Debugging line to check shape of combined features before classification head + output = classification_head(x_combined).squeeze(-1) + outputs.append(output) + print( + f"Output shape for current level: {output.shape}" + ) # Debugging line to check shape of output logits for current level + + return outputs + + +class MultiLevelCrossEntropyLoss(nn.Module): + def __init__(self, num_classes: Optional[list[int]] = None): + super().__init__() + self.num_classes = num_classes + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, outputs, labels): + total_loss = 0 + for idx, output in enumerate(outputs): + label = labels[:, idx] # (batch_size,) + if self.num_classes is not None: + total_loss += self.loss_fn(output.squeeze(), label) * self.num_classes[idx] + else: + total_loss += self.loss_fn(output.squeeze(), label) + + if self.num_classes is not None: + total_weight = 
sum(self.num_classes) + else: + total_weight = len(outputs) + + return total_loss / total_weight # average loss across levels + + +model = MultiLevelTextClassificationModel( + token_embedder=token_embedder, + sentence_embedders=all_sentence_embedders, + classification_heads=all_classification_heads, + categorical_variable_net=categorical_var_net, +) + +module = TextClassificationModule( + model=model, + loss=MultiLevelCrossEntropyLoss(), + optimizer=torch.optim.Adam, + optimizer_params={"lr": 1e-3}, + scheduler=None, + scheduler_params=None, +) + +print(model.num_classes) + +batch = next(iter(train_dataloader)) +print(batch["labels"].shape) +outputs = model(**batch) +print(f"Outputs shapes: {[output.shape for output in outputs]}") + + +ttc = torchTextClassifiers.from_model( + tokenizer=tokenizer, pytorch_model=model, value_encoder=value_encoder +) + +training_config = TrainingConfig( + num_epochs=1, + batch_size=6, + lr=1e-3, + raw_categorical_inputs=True, + loss=MultiLevelCrossEntropyLoss(num_classes=value_encoder.num_classes), +) + +ttc.train( + X_train=df[["text", "category", "color"]].values, + y_train=df[["label_level_1", "label_level_2", "label_level_3"]].values, + training_config=training_config, +) + + +print( + ttc.predict( + X_test=df[["text", "category", "color"]].values, + ) +) diff --git a/examples/simple_explainability_example.py b/examples/simple_explainability_example.py index 20cc1d5..1febde1 100644 --- a/examples/simple_explainability_example.py +++ b/examples/simple_explainability_example.py @@ -14,7 +14,6 @@ from torchTextClassifiers.tokenizers import WordPieceTokenizer from torchTextClassifiers.utilities.plot_explainability import ( map_attributions_to_char, - map_attributions_to_word, ) @@ -23,8 +22,8 @@ def main(): SEED = 42 # Set environment variables for full reproducibility - os.environ['PYTHONHASHSEED'] = str(SEED) - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + os.environ["PYTHONHASHSEED"] = str(SEED) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Use PyTorch Lightning's seed_everything for comprehensive seeding seed_everything(SEED, workers=True) @@ -36,64 +35,94 @@ def main(): # Suppress PyTorch Lightning warnings for cleaner output warnings.filterwarnings( - 'ignore', - message='.*', - category=UserWarning, - module='pytorch_lightning' + "ignore", message=".*", category=UserWarning, module="pytorch_lightning" ) print("šŸ” Simple Explainability Example") # Enhanced training data with more diverse examples - X_train = np.array([ - # Positive examples - "I love this product", - "Great quality and excellent service", - "Amazing design and fantastic performance", - "Outstanding value for money", - "Excellent customer support team", - "Love the innovative features", - "Perfect solution for my needs", - "Highly recommend this item", - "Superb build quality", - "Wonderful experience overall", - "Great value and fast delivery", - "Excellent product with amazing results", - "Love this fantastic design", - "Perfect quality and great price", - "Amazing customer service experience", - - # Negative examples - "This is terrible quality", - "Poor design and cheap materials", - "Awful experience with this product", - "Terrible customer service response", - "Completely disappointing purchase", - "Poor quality and overpriced item", - "Awful build quality issues", - "Terrible value for money", - "Disappointing performance results", - "Poor service and bad experience", - "Awful design and cheap feel", - "Terrible product with many issues", - "Disappointing quality 
and poor value", - "Bad experience with customer support", - "Poor construction and awful materials" - ]) - - y_train = np.array([ - # Positive labels (1) - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - # Negative labels (0) - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - ]) - - X_val = np.array([ - "Good product with decent quality", - "Bad quality and poor service", - "Excellent value and great design", - "Terrible experience and awful quality" - ]) + X_train = np.array( + [ + # Positive examples + "I love this product", + "Great quality and excellent service", + "Amazing design and fantastic performance", + "Outstanding value for money", + "Excellent customer support team", + "Love the innovative features", + "Perfect solution for my needs", + "Highly recommend this item", + "Superb build quality", + "Wonderful experience overall", + "Great value and fast delivery", + "Excellent product with amazing results", + "Love this fantastic design", + "Perfect quality and great price", + "Amazing customer service experience", + # Negative examples + "This is terrible quality", + "Poor design and cheap materials", + "Awful experience with this product", + "Terrible customer service response", + "Completely disappointing purchase", + "Poor quality and overpriced item", + "Awful build quality issues", + "Terrible value for money", + "Disappointing performance results", + "Poor service and bad experience", + "Awful design and cheap feel", + "Terrible product with many issues", + "Disappointing quality and poor value", + "Bad experience with customer support", + "Poor construction and awful materials", + ] + ) + + y_train = np.array( + [ + # Positive labels (1) + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + # Negative labels (0) + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ] + ) + + X_val = np.array( + [ + "Good product with decent quality", + "Bad quality and poor service", + "Excellent value and great design", + "Terrible experience and awful quality", + ] + ) y_val = np.array([1, 0, 1, 0]) # Create and train tokenizer @@ -105,17 +134,11 @@ def main(): # Create model configuration print("\nšŸ”§ Creating model configuration...") - model_config = ModelConfig( - embedding_dim=50, - num_classes=2 - ) + model_config = ModelConfig(embedding_dim=50, num_classes=2) # Create classifier print("\nšŸ”Ø Creating classifier...") - classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=model_config - ) + classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config) print("āœ… Classifier created successfully!") # Train the model @@ -126,28 +149,25 @@ def main(): lr=1e-3, patience_early_stopping=5, num_workers=0, - trainer_params={'deterministic': True} + trainer_params={"deterministic": True}, ) classifier.train( - X_train, y_train, - training_config=training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_train, y_train, training_config=training_config, X_val=X_val, y_val=y_val, verbose=True ) print("āœ… Training completed!") - + # Test examples with different sentiments test_texts = [ "This product is amazing!", "Poor quality and terrible service", "Great value for money", "Completely disappointing and awful experience", - "Love this excellent design" + "Love this excellent design", ] - + print(f"\nšŸ” Testing explainability on {len(test_texts)} examples:") print("=" * 60) - + for i, test_text in enumerate(test_texts, 1): print(f"\nšŸ“ Example {i}:") print(f"Text: '{test_text}'") @@ -159,7 +179,9 @@ def 
main(): # Extract prediction prediction = result["prediction"][0][0].item() confidence = result["confidence"][0][0].item() - print(f"Prediction: {'Positive' if prediction == 1 else 'Negative'} (confidence: {confidence:.4f})") + print( + f"Prediction: {'Positive' if prediction == 1 else 'Negative'} (confidence: {confidence:.4f})" + ) # Extract attributions and mapping info attributions = result["attributions"][0][0] # shape: (seq_len,) @@ -170,7 +192,7 @@ def main(): char_attributions = map_attributions_to_char( attributions.unsqueeze(0), # Add batch dimension: (1, seq_len) offset_mapping, - test_text + test_text, )[0] # Get first result print("\nšŸ“Š Character-Level Contribution Visualization:") @@ -187,7 +209,7 @@ def main(): for word in words: word_len = len(word) # Get attributions for this word - word_attrs = char_attributions[char_idx:char_idx + word_len] + word_attrs = char_attributions[char_idx : char_idx + word_len] if len(word_attrs) > 0: avg_attr = sum(word_attrs) / len(word_attrs) bar_length = int((avg_attr / max_attr) * bar_width) if max_attr > 0 else 0 @@ -202,7 +224,7 @@ def main(): word_scores = [] for word in words: word_len = len(word) - word_attrs = char_attributions[char_idx:char_idx + word_len] + word_attrs = char_attributions[char_idx : char_idx + word_len] if len(word_attrs) > 0: word_scores.append((word, sum(word_attrs) / len(word_attrs))) char_idx += word_len + 1 @@ -214,33 +236,34 @@ def main(): except Exception as e: print(f"āš ļø Explainability failed: {e}") import traceback + traceback.print_exc() - + # Analysis completed for this example print(f"āœ… Analysis completed for example {i}") - + print(f"\nšŸŽ‰ Explainability analysis completed for {len(test_texts)} examples!") - + # Interactive section for user input (only if --interactive flag is provided) if "--interactive" in sys.argv: - print("\n" + "="*60) + print("\n" + "=" * 60) print("šŸŽÆ Interactive Explainability Mode") - print("="*60) + print("=" * 60) print("Enter your own text to see predictions and explanations!") print("Type 'quit' or 'exit' to end the session.\n") - + while True: try: user_text = input("šŸ’¬ Enter text: ").strip() - - if user_text.lower() in ['quit', 'exit', 'q']: + + if user_text.lower() in ["quit", "exit", "q"]: print("šŸ‘‹ Thanks for using the explainability tool!") break - + if not user_text: print("āš ļø Please enter some text.") continue - + print(f"\nšŸ” Analyzing: '{user_text}'") # Get prediction with explainability @@ -262,7 +285,7 @@ def main(): char_attributions = map_attributions_to_char( attributions.unsqueeze(0), # Add batch dimension: (1, seq_len) offset_mapping, - user_text + user_text, )[0] # Get first result print("\nšŸ“Š Character-Level Contribution Visualization:") @@ -279,10 +302,12 @@ def main(): for word in words: word_len = len(word) # Get attributions for this word - word_attrs = char_attributions[char_idx:char_idx + word_len] + word_attrs = char_attributions[char_idx : char_idx + word_len] if len(word_attrs) > 0: avg_attr = sum(word_attrs) / len(word_attrs) - bar_length = int((avg_attr / max_attr) * bar_width) if max_attr > 0 else 0 + bar_length = ( + int((avg_attr / max_attr) * bar_width) if max_attr > 0 else 0 + ) bar = "ā–ˆ" * bar_length print(f"{word:>15} | {bar:<40} {avg_attr:.4f}") char_idx += word_len + 1 # +1 for space @@ -294,23 +319,26 @@ def main(): word_scores = [] for word in words: word_len = len(word) - word_attrs = char_attributions[char_idx:char_idx + word_len] + word_attrs = char_attributions[char_idx : char_idx + word_len] if 
len(word_attrs) > 0: word_scores.append((word, sum(word_attrs) / len(word_attrs))) char_idx += word_len + 1 if word_scores: top_word, top_score = max(word_scores, key=lambda x: x[1]) - print(f"šŸ’” Most influential word: '{top_word}' (avg score: {top_score:.4f})") + print( + f"šŸ’” Most influential word: '{top_word}' (avg score: {top_score:.4f})" + ) except Exception as e: print(f"āš ļø Explainability failed: {e}") print("šŸ” Prediction available, but detailed explanation unavailable.") import traceback + traceback.print_exc() - - print("\n" + "-"*50) - + + print("\n" + "-" * 50) + except KeyboardInterrupt: print("\nšŸ‘‹ Session interrupted. Goodbye!") break @@ -318,9 +346,11 @@ def main(): print(f"āš ļø Error: {e}") continue else: - print("\nšŸ’” Tip: Use --interactive flag to enter interactive mode for custom text analysis!") + print( + "\nšŸ’” Tip: Use --interactive flag to enter interactive mode for custom text analysis!" + ) print(" Example: uv run python examples/simple_explainability_example.py --interactive") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/using_additional_features.py b/examples/using_additional_features.py index 90ab9b2..c740510 100644 --- a/examples/using_additional_features.py +++ b/examples/using_additional_features.py @@ -7,7 +7,6 @@ """ import os -import random import time import warnings @@ -20,9 +19,11 @@ from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers from torchTextClassifiers.tokenizers import WordPieceTokenizer + # Note: SimpleTextWrapper is not available in the current version # from torchTextClassifiers.classifiers.simple_text_classifier import SimpleTextConfig, SimpleTextWrapper + def stratified_split_rare_labels(X, y, test_size=0.2, min_train_samples=1): # Get unique labels and their frequencies unique_labels, label_counts = np.unique(y, return_counts=True) @@ -54,10 +55,10 @@ def stratified_split_rare_labels(X, y, test_size=0.2, min_train_samples=1): def merge_cat(cat): - if cat in ['World', 'Top News', 'Europe', 'Italia', 'U.S.', 'Top Stories']: - return 'World News' - if cat in ['Sci/Tech', 'Software and Developement', 'Toons', 'Health', 'Music Feeds']: - return 'Tech and Stuff' + if cat in ["World", "Top News", "Europe", "Italia", "U.S.", "Top Stories"]: + return "World News" + if cat in ["Sci/Tech", "Software and Developement", "Toons", "Health", "Music Feeds"]: + return "Tech and Stuff" return cat @@ -67,10 +68,9 @@ def load_and_prepare_data(): df = pd.read_parquet("https://minio.lab.sspcloud.fr/h4njlg/public/ag_news_full_1M.parquet") df = df.sample(10000, random_state=42) # Smaller sample to avoid disk space issues print(f"āœ… Loaded {len(df)} samples from AG NEWS dataset") - - df['category_final'] = df['category'].apply(lambda x: merge_cat(x)) - df['title_headline'] = df['title'] + '\n####\n' + df['description'] + df["category_final"] = df["category"].apply(lambda x: merge_cat(x)) + df["title_headline"] = df["title"] + "\n####\n" + df["description"] # categorical_features = None # text_feature = "title_headline" @@ -79,29 +79,32 @@ def load_and_prepare_data(): source_encoder = LabelEncoder() df["title_headline_processed"] = df["title_headline"] - df["source_encoded"] = source_encoder.fit_transform(df['source']) + df["source_encoded"] = source_encoder.fit_transform(df["source"]) - X_text_only = df[['title_headline_processed']].values - X_mixed = df[['title_headline_processed', "source_encoded"]].values - y = df['category_final'].values + X_text_only = 
df[["title_headline_processed"]].values + X_mixed = df[["title_headline_processed", "source_encoded"]].values + y = df["category_final"].values encoder = LabelEncoder() y = encoder.fit_transform(y) return X_text_only, X_mixed, y, encoder - + def train_and_evaluate_model(X, y, model_name, use_categorical=False, use_simple=False): """Train and evaluate a FastText model""" print(f"\nšŸŽÆ Training {model_name}...") - - + # Split data twice: first for train/temp, then temp into validation/test X_train, X_temp, y_train, y_temp = stratified_split_rare_labels( - X, y, test_size=0.1 # 40% for validation + test + X, + y, + test_size=0.1, # 40% for validation + test ) X_val, X_test, y_val, y_test = stratified_split_rare_labels( - X_temp, y_temp, test_size=0.5 # Split temp 50/50 into validation and test + X_temp, + y_temp, + test_size=0.5, # Split temp 50/50 into validation and test ) - + # Note: SimpleTextWrapper is not available in the current version # The use_simple branch has been disabled if use_simple: @@ -133,22 +136,16 @@ def train_and_evaluate_model(X, y, model_name, use_categorical=False, use_simple embedding_dim=50, categorical_vocabulary_sizes=vocab_sizes, categorical_embedding_dims=10, - num_classes=5 + num_classes=5, ) print(f" Categorical vocabulary sizes: {vocab_sizes}") else: # For text-only model - model_config = ModelConfig( - embedding_dim=50, - num_classes=5 - ) + model_config = ModelConfig(embedding_dim=50, num_classes=5) # Create classifier print(" šŸ”Ø Creating classifier...") - classifier = torchTextClassifiers( - tokenizer=tokenizer, - model_config=model_config - ) + classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config) print(" āœ… Classifier created successfully!") # Training configuration @@ -158,10 +155,7 @@ def train_and_evaluate_model(X, y, model_name, use_categorical=False, use_simple lr=0.001, patience_early_stopping=3, num_workers=0, - trainer_params={ - 'enable_progress_bar': True, - 'deterministic': True - } + trainer_params={"enable_progress_bar": True, "deterministic": True}, ) # Create and build model @@ -170,10 +164,7 @@ def train_and_evaluate_model(X, y, model_name, use_categorical=False, use_simple # Train model print(" šŸŽÆ Training model...") classifier.train( - X_train, y_train, - training_config=training_config, - X_val=X_val, y_val=y_val, - verbose=True + X_train, y_train, training_config=training_config, X_val=X_val, y_val=y_val, verbose=True ) training_time = time.time() - start_time @@ -201,26 +192,24 @@ def train_and_evaluate_model(X, y, model_name, use_categorical=False, use_simple print(f" āš ļø Validation failed: {e}") test_accuracy = 0.0 predictions = np.zeros(len(y_test)) - + return { - 'model_name': model_name, - 'test_accuracy': test_accuracy, - 'training_time': training_time, - 'predictions': predictions, - 'y_test': y_test, - 'classifier': classifier + "model_name": model_name, + "test_accuracy": test_accuracy, + "training_time": training_time, + "predictions": predictions, + "y_test": y_test, + "classifier": classifier, } - - def main(): # Set seed for reproducibility SEED = 42 # Set environment variables for full reproducibility - os.environ['PYTHONHASHSEED'] = str(SEED) - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + os.environ["PYTHONHASHSEED"] = str(SEED) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Use PyTorch Lightning's seed_everything for comprehensive seeding seed_everything(SEED, workers=True) @@ -232,10 +221,7 @@ def main(): # Suppress PyTorch Lightning warnings for cleaner output 
warnings.filterwarnings( - 'ignore', - message='.*', - category=UserWarning, - module='pytorch_lightning' + "ignore", message=".*", category=UserWarning, module="pytorch_lightning" ) print("šŸ”€ Classifier: Categorical Features Comparison") @@ -245,11 +231,11 @@ def main(): # Load and prepare data (same as notebook) X_text_only, X_mixed, y, encoder = load_and_prepare_data() - + # Train models - print(f"\nšŸš€ Training Models:") + print("\nšŸš€ Training Models:") print("-" * 40) - + # Text-only model results_text_only = train_and_evaluate_model( X_text_only, y, "Text-Only Classifier", use_categorical=False @@ -264,21 +250,26 @@ def main(): # results_tfidf = train_and_evaluate_model(X_text_only, y, "TF-IDF classifier", use_categorical=False, use_simple=True) # Compare results - print(f"\nšŸ“Š Results Comparison:") + print("\nšŸ“Š Results Comparison:") print("=" * 50) print(f"{'Model':<25}{'Test Acc':<11} {'Time (s)':<10}") print("-" * 50) - print(f"{'Text-Only':<25} " - f"{results_text_only['test_accuracy']:<11.3f} {results_text_only['training_time']:<10.1f}") - print(f"{'Mixed Features':<25} " - f"{results_mixed['test_accuracy']:<11.3f} {results_mixed['training_time']:<10.1f}") + print( + f"{'Text-Only':<25} " + f"{results_text_only['test_accuracy']:<11.3f} {results_text_only['training_time']:<10.1f}" + ) + print( + f"{'Mixed Features':<25} " + f"{results_mixed['test_accuracy']:<11.3f} {results_mixed['training_time']:<10.1f}" + ) # Calculate improvements - acc_improvement = results_mixed['test_accuracy'] - results_text_only['test_accuracy'] - time_overhead = results_mixed['training_time'] - results_text_only['training_time'] - + acc_improvement = results_mixed["test_accuracy"] - results_text_only["test_accuracy"] + time_overhead = results_mixed["training_time"] - results_text_only["training_time"] + print("-" * 50) print(f"Test Accuracy Improvement: {acc_improvement:+.3f}") print(f"Training Time Overhead: {time_overhead:+.1f}s") + if __name__ == "__main__": - main() \ No newline at end of file + main() From 0fab098c5174a601fa0baa77dcd03e849632b614 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Fri, 24 Apr 2026 11:48:28 +0000 Subject: [PATCH 3/6] feat!(multilevel): support for several classifications simultaneously - basically, support for "list" labels and list num classes --- torchTextClassifiers/dataset/dataset.py | 4 +- torchTextClassifiers/model/lightning.py | 89 +- torchTextClassifiers/torchTextClassifiers.py | 1795 +++++++++-------- .../value_encoder/value_encoder.py | 111 +- 4 files changed, 1074 insertions(+), 925 deletions(-) diff --git a/torchTextClassifiers/dataset/dataset.py b/torchTextClassifiers/dataset/dataset.py index 6475567..064d78b 100644 --- a/torchTextClassifiers/dataset/dataset.py +++ b/torchTextClassifiers/dataset/dataset.py @@ -52,6 +52,8 @@ def __init__( logger.warning( "ragged_multilabel set to True but max label value is 1. If your labels are already one-hot encoded, set ragged_multilabel to False. Otherwise computations are likely to be wrong." 
) + elif not self.ragged_multilabel and self.labels is not None: + self.labels = torch.tensor(labels, dtype=torch.long) def __len__(self): return len(self.texts) @@ -100,7 +102,7 @@ def collate_fn(self, batch): labels_tensor[rows, cols] = 1 else: - labels_tensor = torch.tensor(labels) + labels_tensor = torch.stack(list(labels)) else: labels_tensor = None diff --git a/torchTextClassifiers/model/lightning.py b/torchTextClassifiers/model/lightning.py index 8726f20..886b347 100644 --- a/torchTextClassifiers/model/lightning.py +++ b/torchTextClassifiers/model/lightning.py @@ -36,11 +36,23 @@ def __init__( scheduler_interval: Scheduler interval. """ super().__init__() - self.save_hyperparameters(ignore=["model"]) + self.save_hyperparameters(ignore=["model", "loss"]) self.model = model self.loss = loss - self.accuracy_fn = Accuracy(task="multiclass", num_classes=self.model.num_classes) + + if not hasattr(self.model, "num_classes") or self.model.num_classes is None: + raise ValueError("Model must have num_classes attribute for accuracy calculation.") + + if isinstance(self.model.num_classes, list): + self.accuracy_fn = torch.nn.ModuleList( + [Accuracy(task="multiclass", num_classes=n) for n in self.model.num_classes] + ) + self.multilevel_accuracy = True + else: + self.accuracy_fn = Accuracy(task="multiclass", num_classes=self.model.num_classes) + self.multilevel_accuracy = False + self.optimizer = optimizer self.optimizer_params = optimizer_params self.scheduler = scheduler @@ -62,71 +74,42 @@ def forward(self, batch) -> torch.Tensor: categorical_vars=batch.get("categorical_vars", None), ) - def training_step(self, batch, batch_idx: int) -> torch.Tensor: - """ - Training step. - - Args: - batch (List[torch.LongTensor]): Training batch. - batch_idx (int): Batch index. - - Returns (torch.Tensor): Loss tensor. - """ - + def step(self, batch) -> tuple[torch.Tensor, torch.Tensor | list[torch.Tensor]]: targets = batch["labels"] - outputs = self.forward(batch) - if isinstance(self.loss, torch.nn.BCEWithLogitsLoss): targets = targets.float() - loss = self.loss(outputs, targets) - self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True) - accuracy = self.accuracy_fn(outputs, targets) - self.log("train_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True) + if self.multilevel_accuracy: + accuracy = [ + fn(out, targets[:, i]) for i, (fn, out) in enumerate(zip(self.accuracy_fn, outputs)) + ] + else: + accuracy = self.accuracy_fn(outputs, targets) + return loss, accuracy - torch.cuda.empty_cache() + def _log_accuracy(self, accuracy: torch.Tensor | list[torch.Tensor], prefix: str, **kwargs): + if isinstance(accuracy, list): + for i, acc in enumerate(accuracy): + self.log(f"{prefix}_accuracy_level_{i}", acc, **kwargs) + else: + self.log(f"{prefix}_accuracy", accuracy, **kwargs) + def training_step(self, batch, batch_idx: int) -> torch.Tensor: + loss, accuracy = self.step(batch) + self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True) + self._log_accuracy(accuracy, "train", on_epoch=True, on_step=False, prog_bar=True) + torch.cuda.empty_cache() return loss def validation_step(self, batch, batch_idx: int): - """ - Validation step. - - Args: - batch (List[torch.LongTensor]): Validation batch. - batch_idx (int): Batch index. - - Returns (torch.Tensor): Loss tensor. 
- """ - targets = batch["labels"] - - outputs = self.forward(batch) - - loss = self.loss(outputs, targets) + loss, accuracy = self.step(batch) self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True) - - accuracy = self.accuracy_fn(outputs, targets) - self.log("val_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True) + self._log_accuracy(accuracy, "val", on_epoch=True, on_step=False, prog_bar=True) return loss def test_step(self, batch, batch_idx: int): - """ - Test step. - - Args: - batch (List[torch.LongTensor]): Test batch. - batch_idx (int): Batch index. - - Returns (torch.Tensor): Loss tensor. - """ - targets = batch["labels"] - - outputs = self.forward(batch) - loss = self.loss(outputs, targets) - - accuracy = self.accuracy_fn(outputs, targets) - + loss, accuracy = self.step(batch) return loss, accuracy def predict_step(self, batch, batch_idx: int = 0, dataloader_idx: int = 0): diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py index b5aeb10..0541c99 100644 --- a/torchTextClassifiers/torchTextClassifiers.py +++ b/torchTextClassifiers/torchTextClassifiers.py @@ -1,848 +1,947 @@ -import logging -import pickle -import time -from dataclasses import asdict, dataclass, field -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Type, Union - -try: - from captum.attr import LayerIntegratedGradients - - HAS_CAPTUM = True -except ImportError: - HAS_CAPTUM = False - - -import numpy as np -import pytorch_lightning as pl -import torch -from pytorch_lightning.callbacks import ( - EarlyStopping, - LearningRateMonitor, - ModelCheckpoint, -) - -from torchTextClassifiers.dataset import TextClassificationDataset -from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule -from torchTextClassifiers.model.components import ( - AttentionConfig, - CategoricalForwardType, - CategoricalVariableNet, - ClassificationHead, - LabelAttentionConfig, - SentenceEmbedder, - SentenceEmbedderConfig, - TokenEmbedder, - TokenEmbedderConfig, -) -from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput -from torchTextClassifiers.value_encoder import ValueEncoder - -logger = logging.getLogger(__name__) - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - handlers=[logging.StreamHandler()], -) - - -@dataclass -class ModelConfig: - """Base configuration class for text classifiers.""" - - embedding_dim: int - num_classes: Optional[int] = None - categorical_vocabulary_sizes: Optional[List[int]] = None - categorical_embedding_dims: Optional[Union[List[int], int]] = None - attention_config: Optional[AttentionConfig] = None - n_heads_label_attention: Optional[int] = None - aggregation_method: Optional[str] = "mean" - - def to_dict(self) -> Dict[str, Any]: - return asdict(self) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ModelConfig": - return cls(**data) - - -@dataclass -class TrainingConfig: - num_epochs: int - batch_size: int - lr: float - raw_categorical_inputs: Optional[bool] = True - raw_labels: Optional[bool] = True - loss: torch.nn.Module = field(default_factory=lambda: torch.nn.CrossEntropyLoss()) - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam - scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None - accelerator: str = "auto" - num_workers: int = 12 - patience_early_stopping: int = 3 - dataloader_params: Optional[dict] = None - 
trainer_params: Optional[dict] = None - optimizer_params: Optional[dict] = None - scheduler_params: Optional[dict] = None - save_path: Optional[str] = "my_ttc" - - def to_dict(self) -> Dict[str, Any]: - data = asdict(self) - # Serialize loss and scheduler as their class names - data["loss"] = self.loss.__class__.__name__ - if self.scheduler is not None: - data["scheduler"] = self.scheduler.__name__ - return data - - -class torchTextClassifiers: - """Generic text classifier framework supporting multiple architectures. - - Given a tokenizer and model configuration, this class initializes: - - Text embedding layer (if needed) - - Categorical variable embedding network (if categorical variables are provided) - - Classification head - The resulting model can be trained using PyTorch Lightning and used for predictions. - - """ - - def __init__( - self, - tokenizer: BaseTokenizer, - model_config: ModelConfig, - ragged_multilabel: bool = False, - value_encoder: Optional[ValueEncoder] = None, - ): - """Initialize the torchTextClassifiers instance. - - Args: - tokenizer: A tokenizer instance for text preprocessing - model_config: Configuration parameters for the text classification model - ragged_multilabel: Whether to use ragged multilabel classification - value_encoder: Optional ValueEncoder for encoding - raw string (or mixed) categorical values to integers. Build it - beforehand from DictEncoder or sklearn LabelEncoder instances and - pass it here. If None, categorical columns in X must already be - integer-encoded. - - Example: - >>> from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers - >>> from torchTextClassifiers.value_encoder import ValueEncoder, DictEncoder - >>> # Build one DictEncoder per categorical feature - >>> encoders = {str(i): DictEncoder({v: j for j, v in enumerate(sorted(set(X_categorical[:, i])))}) - ... for i in range(X_categorical.shape[1])} - >>> encoder = ValueEncoder(encoders) - >>> model_config = ModelConfig( - ... embedding_dim=10, - ... categorical_vocabulary_sizes=encoder.vocabulary_sizes, - ... categorical_embedding_dims=[10, 5], - ... num_classes=10, - ... ) - >>> ttc = torchTextClassifiers( - ... tokenizer=tokenizer, - ... model_config=model_config, - ... value_encoder=encoder, - ... ) - """ - - self.model_config = model_config - self.tokenizer = tokenizer - self.ragged_multilabel = ragged_multilabel - self.value_encoder: ValueEncoder | None = value_encoder - - if hasattr(self.tokenizer, "trained"): - if not self.tokenizer.trained: - raise RuntimeError( - f"Tokenizer {type(self.tokenizer)} must be trained before initializing the classifier." - ) - - self.vocab_size = tokenizer.vocab_size - self.embedding_dim = model_config.embedding_dim - - if self.value_encoder is not None: - if (model_config.num_classes != self.value_encoder.num_classes) or ( - model_config.categorical_vocabulary_sizes != self.value_encoder.vocabulary_sizes - ): - logger.info( - "Overriding model_config num_classes and/or categorical_vocabulary_sizes with values from value_encoder." - ) - self.categorical_vocabulary_sizes = self.value_encoder.vocabulary_sizes - self.num_classes = self.value_encoder.num_classes - else: - self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes - if model_config.num_classes is None: - raise ValueError( - "num_classes must be specified in the model configuration if no value_encoder is provided." 
- ) - self.num_classes = model_config.num_classes - - self.enable_label_attention = model_config.n_heads_label_attention is not None - - if self.tokenizer.output_vectorized: - self.token_embedder = None - logger.info( - "Tokenizer outputs vectorized tokens; skipping TextEmbedder initialization." - ) - self.embedding_dim = self.tokenizer.output_dim - else: - token_embedder_config = TokenEmbedderConfig( - vocab_size=self.vocab_size, - embedding_dim=self.embedding_dim, - padding_idx=tokenizer.padding_idx, - attention_config=model_config.attention_config, - ) - sentence_embedder_config = SentenceEmbedderConfig( - label_attention_config=LabelAttentionConfig( - n_head=model_config.n_heads_label_attention, - num_classes=model_config.num_classes, - embedding_dim=self.embedding_dim, - ) - if self.enable_label_attention - else None, - aggregation_method=model_config.aggregation_method, - ) - self.token_embedder = TokenEmbedder( - token_embedder_config=token_embedder_config, - ) - self.sentence_embedder = SentenceEmbedder( - sentence_embedder_config=sentence_embedder_config - ) - - classif_head_input_dim = self.embedding_dim - if self.categorical_vocabulary_sizes: - self.categorical_var_net = CategoricalVariableNet( - categorical_vocabulary_sizes=self.categorical_vocabulary_sizes, - categorical_embedding_dims=model_config.categorical_embedding_dims, - text_embedding_dim=self.embedding_dim, - ) - - if self.categorical_var_net.forward_type != CategoricalForwardType.SUM_TO_TEXT: - classif_head_input_dim += self.categorical_var_net.output_dim - - else: - self.categorical_var_net = None - - self.classification_head = ClassificationHead( - input_dim=classif_head_input_dim, - num_classes=1 - if self.enable_label_attention - else self.num_classes, # output dim is 1 when using label attention, because embeddings are (num_classes, embedding_dim) - ) - - self.pytorch_model = TextClassificationModel( - token_embedder=self.token_embedder, - sentence_embedder=self.sentence_embedder, - categorical_variable_net=self.categorical_var_net, - classification_head=self.classification_head, - ) - - def train( - self, - X_train: np.ndarray, - y_train: np.ndarray, - training_config: TrainingConfig, - X_val: Optional[np.ndarray] = None, - y_val: Optional[np.ndarray] = None, - verbose: bool = False, - ) -> None: - """Train the classifier using PyTorch Lightning. - - This method handles the complete training process including: - - Data validation and preprocessing - - Dataset and DataLoader creation - - PyTorch Lightning trainer setup with callbacks - - Model training with early stopping - - Best model loading after training - - Note on Checkpoints: - After training, the best model checkpoint is automatically loaded. - This checkpoint contains the full training state (model weights, - optimizer, and scheduler state). Loading uses weights_only=False - as the checkpoint is self-generated and trusted. - - Args: - X_train: Training input data - y_train: Training labels - X_val: Validation input data - y_val: Validation labels - training_config: Configuration parameters for training - verbose: Whether to print training progress information - - - Example: - - >>> training_config = TrainingConfig( - ... lr=1e-3, - ... batch_size=4, - ... num_epochs=1, - ... ) - >>> ttc.train( - ... X_train=X, - ... y_train=Y, - ... X_val=X, - ... y_val=Y, - ... training_config=training_config, - ... 
) - """ - - # Input validation - X_train, y_train = self._check_XY( - X_train, y_train, training_config.raw_categorical_inputs, training_config.raw_labels - ) - - if X_val is not None: - assert y_val is not None, "y_val must be provided if X_val is provided." - if y_val is not None: - assert X_val is not None, "X_val must be provided if y_val is provided." - - X_val: Optional[Dict[str, Any]] = None - if X_val is not None and y_val is not None: - X_val, y_val = self._check_XY(X_val, y_val) - - if ( - X_train["categorical_variables"] is not None - and X_val is not None - and X_val["categorical_variables"] is not None - ): - assert ( - X_train["categorical_variables"].ndim > 1 - and X_train["categorical_variables"].shape[1] - == X_val["categorical_variables"].shape[1] - or X_val["categorical_variables"].ndim == 1 - ), "X_train and X_val must have the same number of columns." - - if verbose: - logger.info("Starting training process...") - - if training_config.accelerator == "auto": - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - else: - device = torch.device(training_config.accelerator) - - self.device = device - - optimizer_params = {"lr": training_config.lr} - if training_config.optimizer_params is not None: - optimizer_params.update(training_config.optimizer_params) - - if training_config.loss is torch.nn.CrossEntropyLoss and self.ragged_multilabel: - logger.warning( - "āš ļø You have set ragged_multilabel to True but are using CrossEntropyLoss. We would recommend to use torch.nn.BCEWithLogitsLoss for multilabel classification tasks." - ) - - self.lightning_module = TextClassificationModule( - model=self.pytorch_model, - loss=training_config.loss, - optimizer=training_config.optimizer, - optimizer_params=optimizer_params, - scheduler=training_config.scheduler, - scheduler_params=training_config.scheduler_params - if training_config.scheduler_params - else {}, - scheduler_interval="epoch", - ) - - self.pytorch_model.to(self.device) - - if verbose: - logger.info(f"Running on: {device}") - - train_dataset = TextClassificationDataset( - texts=X_train["text"], - categorical_variables=X_train["categorical_variables"], # None if no cat vars - tokenizer=self.tokenizer, - labels=y_train.tolist(), - ragged_multilabel=self.ragged_multilabel, - ) - train_dataloader = train_dataset.create_dataloader( - batch_size=training_config.batch_size, - num_workers=training_config.num_workers, - shuffle=True, - **training_config.dataloader_params if training_config.dataloader_params else {}, - ) - - if X_val is not None and y_val is not None: - val_dataset = TextClassificationDataset( - texts=X_val["text"], - categorical_variables=X_val["categorical_variables"], # None if no cat vars - tokenizer=self.tokenizer, - labels=y_val, - ragged_multilabel=self.ragged_multilabel, - ) - val_dataloader = val_dataset.create_dataloader( - batch_size=training_config.batch_size, - num_workers=training_config.num_workers, - shuffle=False, - **training_config.dataloader_params if training_config.dataloader_params else {}, - ) - else: - val_dataloader = None - - # Setup trainer - callbacks = [ - ModelCheckpoint( - monitor="val_loss" if val_dataloader is not None else "train_loss", - save_top_k=1, - save_last=False, - mode="min", - ), - EarlyStopping( - monitor="val_loss" if val_dataloader is not None else "train_loss", - patience=training_config.patience_early_stopping, - mode="min", - ), - LearningRateMonitor(logging_interval="step"), - ] - - trainer_params = { - "accelerator": 
training_config.accelerator, - "callbacks": callbacks, - "max_epochs": training_config.num_epochs, - "num_sanity_val_steps": 2, - "strategy": "auto", - "log_every_n_steps": 1, - "enable_progress_bar": True, - } - - if training_config.trainer_params is not None: - trainer_params.update(training_config.trainer_params) - - trainer = pl.Trainer(**trainer_params) - - torch.cuda.empty_cache() - torch.set_float32_matmul_precision("medium") - - if verbose: - logger.info("Launching training...") - start = time.time() - - trainer.fit(self.lightning_module, train_dataloader, val_dataloader) - - if verbose: - end = time.time() - logger.info(f"Training completed in {end - start:.2f} seconds.") - - best_model_path = trainer.checkpoint_callback.best_model_path - self.checkpoint_path = best_model_path - - self.lightning_module = TextClassificationModule.load_from_checkpoint( - best_model_path, - model=self.pytorch_model, - loss=training_config.loss, - weights_only=False, # Required: checkpoint contains optimizer/scheduler state - ) - - self.pytorch_model = self.lightning_module.model.to(self.device) - - self.save_path = training_config.save_path - self.save(self.save_path) - - self.lightning_module.eval() - - def _check_XY( - self, X: np.ndarray, Y: np.ndarray, raw_categorical_inputs, raw_labels - ) -> Tuple[Dict[str, Any], np.ndarray]: - X_checked = self._check_X(X, raw_categorical_inputs) - Y_checked = self._check_Y(Y, raw_labels) - - if X_checked["text"].shape[0] != len(Y_checked): - raise ValueError("X_train and y_train must have the same number of observations.") - - return X_checked, Y_checked - - @staticmethod - def _check_text_col(X): - assert isinstance( - X, np.ndarray - ), "X must be a numpy array of shape (N,d), with the first column being the text and the rest being the categorical variables." - - try: - if X.ndim > 1: - text = X[:, 0].astype(str) - else: - text = X[:].astype(str) - except ValueError: - logger.error("The first column of X must be castable in string format.") - - return text - - def _check_categorical_variables( - self, X: np.ndarray, raw_categorical_inputs: bool - ) -> np.ndarray: - """Validate and encode categorical variables from X. - - If a ``value_encoder`` was provided at initialization, raw string - or mixed values are encoded to integers via that encoder. Otherwise the - categorical columns must already be integer-encodable. - - Args: - X: Full input array whose first column is text and whose remaining - columns are categorical variables. - - Returns: - Integer-encoded categorical array of shape (N, n_cat_features). - - Raises: - ValueError: If the number of categorical features does not match the - model configuration, if values exceed vocabulary bounds, or if - values cannot be cast to integers and no encoder was provided. - """ - assert self.categorical_var_net is not None - - num_cat_vars = X.shape[1] - 1 if X.ndim > 1 else 0 - - if num_cat_vars != self.categorical_var_net.num_categorical_features: - raise ValueError( - f"X must have the same number of categorical variables as the number of " - f"embedding layers in the categorical net: ({self.categorical_var_net.num_categorical_features})." - ) - - if raw_categorical_inputs: - if self.value_encoder is None: - raise ValueError( - "Raw categorical input encoding is enabled, but no value_encoder was provided. Please provide a ValueEncoder to encode raw categorical values to integers." 
- ) - categorical_variables = self.value_encoder.transform(X[:, 1:]).astype(int) - else: - categorical_variables = X[:, 1:].astype(int) - - for j in range(num_cat_vars): - max_cat_value = categorical_variables[:, j].max() - if max_cat_value >= self.categorical_var_net.categorical_vocabulary_sizes[j]: - raise ValueError( - f"Categorical variable at index {j} has value {max_cat_value} which exceeds " - f"the vocabulary size of {self.categorical_var_net.categorical_vocabulary_sizes[j]}." - ) - - return categorical_variables - - def _check_X(self, X: np.ndarray, raw_categorical_inputs: bool) -> Dict[str, Any]: - text = self._check_text_col(X) - - categorical_variables = None - if self.categorical_var_net is not None: - categorical_variables = self._check_categorical_variables(X, raw_categorical_inputs) - - return {"text": text, "categorical_variables": categorical_variables} - - def _check_Y(self, Y, raw_labels: bool) -> np.ndarray: - if self.ragged_multilabel: - assert isinstance( - Y, list - ), "Y must be a list of lists for ragged multilabel classification." - for row in Y: - assert isinstance(row, list), "Each element of Y must be a list of labels." - - return Y - - else: - assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)." - assert ( - len(Y.shape) == 1 or len(Y.shape) == 2 - ), "Y must be a numpy array of shape (N,) or (N, num_labels)." - - if raw_labels: - if self.value_encoder is None: - raise ValueError( - "Raw label encoding is enabled, but no value_encoder was provided. Please provide a ValueEncoder to encode raw labels to integers." - ) - Y = self.value_encoder.transform_labels(Y) - Y = Y.astype(int) - - if Y.max() >= self.num_classes or Y.min() < 0: - raise ValueError( - f"Y contains class labels outside the range [0, {self.num_classes - 1}]." - ) - - return Y - - def predict( - self, - X_test: np.ndarray, - raw_categorical_inputs: bool = True, - top_k=1, - explain_with_label_attention: bool = False, - explain_with_captum=False, - ): - """ - Args: - X_test (np.ndarray): input data to predict on, shape (N,d) where the first column is text and the rest are categorical variables - top_k (int): for each sentence, return the top_k most likely predictions (default: 1) - explain_with_label_attention (bool): if enabled, use attention matrix labels x tokens to have an explanation of the prediction (default: False) - explain_with_captum (bool): launch gradient integration with Captum for explanation (default: False) - - Returns: A dictionary containing the following fields: - - predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes to the query. - - confidence (torch.Tensor, shape (len(text), top_k)): A tensor array containing the corresponding confidence scores. - - if explain is True: - - attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text. - """ - - explain = explain_with_label_attention or explain_with_captum - if explain: - return_offsets_mapping = True # to be passed to the tokenizer - return_word_ids = True - if self.pytorch_model.token_embedder is None: - raise RuntimeError( - "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs." - ) - else: - if explain_with_captum: - if not HAS_CAPTUM: - raise ImportError( - "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'." 
- ) - lig = LayerIntegratedGradients( - self.pytorch_model, self.pytorch_model.token_embedder.embedding_layer - ) # initialize a Captum layer gradient integrator - if explain_with_label_attention: - if not self.enable_label_attention: - raise RuntimeError( - "Label attention explainability is enabled, but the model was not configured with label attention. Please enable label attention in the model configuration during initialization and retrain." - ) - else: - return_offsets_mapping = False - return_word_ids = False - - X_test = self._check_X(X_test, raw_categorical_inputs) - text = X_test["text"] - categorical_variables = X_test["categorical_variables"] - - self.pytorch_model.eval().cpu() - - tokenize_output = self.tokenizer.tokenize( - text.tolist(), - return_offsets_mapping=return_offsets_mapping, - return_word_ids=return_word_ids, - ) - - if not isinstance(tokenize_output, TokenizerOutput): - raise TypeError( - f"Expected TokenizerOutput, got {type(tokenize_output)} from tokenizer.tokenize method." - ) - - encoded_text = tokenize_output.input_ids # (batch_size, seq_len) - attention_mask = tokenize_output.attention_mask # (batch_size, seq_len) - - if categorical_variables is not None: - categorical_vars = torch.tensor( - categorical_variables, dtype=torch.float32 - ) # (batch_size, num_categorical_features) - else: - categorical_vars = torch.empty((encoded_text.shape[0], 0), dtype=torch.float32) - - model_output = self.pytorch_model( - encoded_text, - attention_mask, - categorical_vars, - return_label_attention_matrix=explain_with_label_attention, - ) # forward pass, contains the prediction scores (len(text), num_classes) - pred = ( - model_output["logits"] if explain_with_label_attention else model_output - ) # (batch_size, num_classes) - - label_attention_matrix = ( - model_output["label_attention_matrix"] if explain_with_label_attention else None - ) - - label_scores = pred.detach().cpu().softmax(dim=1) # convert to probabilities - - label_scores_topk = torch.topk(label_scores, k=top_k, dim=1) - - integer_predictions = label_scores_topk.indices # integer class indices (needed for captum) - if self.value_encoder is not None: - predictions = self.value_encoder.inverse_transform_labels(integer_predictions.numpy()) - else: - predictions = integer_predictions - - confidence = torch.round(label_scores_topk.values, decimals=2) # and their scores - - if explain: - if explain_with_captum: - # Captum explanations - captum_attributions = [] - for k in range(top_k): - attributions = lig.attribute( - (encoded_text, attention_mask, categorical_vars), - target=integer_predictions[:, k], - ) # (batch_size, seq_len) - attributions = attributions.sum(dim=-1) - captum_attributions.append(attributions.detach().cpu()) - - captum_attributions = torch.stack( - captum_attributions, dim=1 - ) # (batch_size, top_k, seq_len) - else: - captum_attributions = None - - return { - "prediction": predictions, - "confidence": confidence, - "captum_attributions": captum_attributions, - "label_attention_attributions": label_attention_matrix, - "offset_mapping": tokenize_output.offset_mapping, - "word_ids": tokenize_output.word_ids, - } - else: - return { - "prediction": predictions, - "confidence": confidence, - } - - def save(self, path: Union[str, Path]) -> None: - """Save the complete torchTextClassifiers instance to disk. 
- - This saves: - - Model configuration - - Tokenizer state - - PyTorch Lightning checkpoint (if trained) - - All other instance attributes - - Args: - path: Directory path where the model will be saved - - Example: - >>> ttc = torchTextClassifiers(tokenizer, model_config) - >>> ttc.train(X_train, y_train, training_config) - >>> ttc.save("my_model") - """ - path = Path(path) - path.mkdir(parents=True, exist_ok=True) - - # Save the checkpoint if model has been trained - checkpoint_path = None - if hasattr(self, "lightning_module"): - checkpoint_path = path / "model_checkpoint.ckpt" - # Save the current state as a checkpoint - trainer = pl.Trainer() - trainer.strategy.connect(self.lightning_module) - trainer.save_checkpoint(checkpoint_path) - - # Prepare metadata to save - metadata = { - "model_config": self.model_config.to_dict(), - "ragged_multilabel": self.ragged_multilabel, - "vocab_size": self.vocab_size, - "embedding_dim": self.embedding_dim, - "categorical_vocabulary_sizes": self.categorical_vocabulary_sizes, - "num_classes": self.num_classes, - "checkpoint_path": str(checkpoint_path) if checkpoint_path else None, - "device": str(self.device) if hasattr(self, "device") else None, - "has_value_encoder": self.value_encoder is not None, - } - - # Save metadata - with open(path / "metadata.pkl", "wb") as f: - pickle.dump(metadata, f) - - # Save tokenizer - tokenizer_path = path / "tokenizer.pkl" - with open(tokenizer_path, "wb") as f: - pickle.dump(self.tokenizer, f) - - # Save categorical encoder if present - if self.value_encoder is not None: - with open(path / "value_encoder.pkl", "wb") as f: - pickle.dump(self.value_encoder, f) - - logger.info(f"Model saved successfully to {path}") - - @classmethod - def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassifiers": - """Load a torchTextClassifiers instance from disk. - - Args: - path: Directory path where the model was saved - device: Device to load the model on ('auto', 'cpu', 'cuda', etc.) 
- - Returns: - Loaded torchTextClassifiers instance - - Example: - >>> loaded_ttc = torchTextClassifiers.load("my_model") - >>> predictions = loaded_ttc.predict(X_test) - """ - path = Path(path) - - if not path.exists(): - raise FileNotFoundError(f"Model directory not found: {path}") - - # Load metadata - with open(path / "metadata.pkl", "rb") as f: - metadata = pickle.load(f) - - # Load tokenizer - with open(path / "tokenizer.pkl", "rb") as f: - tokenizer = pickle.load(f) - - # Reconstruct model_config - model_config = ModelConfig.from_dict(metadata["model_config"]) - - # Load categorical encoder if one was saved - value_encoder = None - if metadata.get("has_value_encoder"): - encoder_path = path / "value_encoder.pkl" - if encoder_path.exists(): - with open(encoder_path, "rb") as f: - value_encoder = pickle.load(f) - - # Create instance - instance = cls( - tokenizer=tokenizer, - model_config=model_config, - ragged_multilabel=metadata["ragged_multilabel"], - value_encoder=value_encoder, - ) - - # Set device - if device == "auto": - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - else: - device = torch.device(device) - instance.device = device - - # Load checkpoint if it exists - if metadata["checkpoint_path"]: - checkpoint_path = path / "model_checkpoint.ckpt" - if checkpoint_path.exists(): - # Load the checkpoint with weights_only=False since it's our own trusted checkpoint - instance.lightning_module = TextClassificationModule.load_from_checkpoint( - str(checkpoint_path), - model=instance.pytorch_model, - weights_only=False, - ) - instance.pytorch_model = instance.lightning_module.model.to(device) - instance.checkpoint_path = str(checkpoint_path) - logger.info(f"Model checkpoint loaded from {checkpoint_path}") - else: - logger.warning(f"Checkpoint file not found at {checkpoint_path}") - - logger.info(f"Model loaded successfully from {path}") - return instance - - def __repr__(self): - model_type = ( - self.lightning_module.__repr__() - if hasattr(self, "lightning_module") - else self.pytorch_model.__repr__() - ) - - tokenizer_info = self.tokenizer.__repr__() - - cat_forward_type = ( - self.categorical_var_net.forward_type.name - if self.categorical_var_net is not None - else "None" - ) - - lines = [ - "torchTextClassifiers(", - f" tokenizer = {tokenizer_info},", - f" model = {model_type},", - f" categorical_forward_type = {cat_forward_type},", - f" num_classes = {self.model_config.num_classes},", - f" embedding_dim = {self.embedding_dim},", - ")", - ] - return "\n".join(lines) +import logging +import pickle +import time +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast + +try: + from captum.attr import LayerIntegratedGradients + + HAS_CAPTUM = True +except ImportError: + HAS_CAPTUM = False + + +import numpy as np +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import ( + EarlyStopping, + LearningRateMonitor, + ModelCheckpoint, +) +from torch import nn + +from torchTextClassifiers.dataset import TextClassificationDataset +from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule +from torchTextClassifiers.model.components import ( + AttentionConfig, + CategoricalForwardType, + CategoricalVariableNet, + ClassificationHead, + LabelAttentionConfig, + SentenceEmbedder, + SentenceEmbedderConfig, + TokenEmbedder, + TokenEmbedderConfig, +) +from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput 
+from torchTextClassifiers.value_encoder import ValueEncoder + +logger = logging.getLogger(__name__) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler()], +) + + +@dataclass +class ModelConfig: + """Base configuration class for text classifiers.""" + + embedding_dim: int + num_classes: Optional[int | list[int]] = None + categorical_vocabulary_sizes: Optional[List[int]] = None + categorical_embedding_dims: Optional[Union[List[int], int]] = None + attention_config: Optional[AttentionConfig] = None + n_heads_label_attention: Optional[int] = None + aggregation_method: Optional[str] = "mean" + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ModelConfig": + return cls(**data) + + +@dataclass +class TrainingConfig: + num_epochs: int + batch_size: int + lr: float + raw_categorical_inputs: Optional[bool] = True + raw_labels: Optional[bool] = True + loss: torch.nn.Module = field(default_factory=lambda: torch.nn.CrossEntropyLoss()) + optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam + scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None + accelerator: str = "auto" + num_workers: int = 12 + patience_early_stopping: int = 3 + dataloader_params: Optional[dict] = None + trainer_params: Optional[dict] = None + optimizer_params: Optional[dict] = None + scheduler_params: Optional[dict] = None + save_path: Optional[str] = "my_ttc" + + def to_dict(self) -> Dict[str, Any]: + data = asdict(self) + # Serialize loss and scheduler as their class names + data["loss"] = self.loss.__class__.__name__ + if self.scheduler is not None: + data["scheduler"] = self.scheduler.__name__ + return data + + +class torchTextClassifiers: + """Generic text classifier framework supporting multiple architectures. + + Given a tokenizer and model configuration, this class initializes: + - Text embedding layer (if needed) + - Categorical variable embedding network (if categorical variables are provided) + - Classification head + The resulting model can be trained using PyTorch Lightning and used for predictions. + + """ + + def __init__( + self, + tokenizer: BaseTokenizer, + model_config: ModelConfig, + ragged_multilabel: bool = False, + value_encoder: Optional[ValueEncoder] = None, + ): + """Initialize the torchTextClassifiers instance. + + Args: + tokenizer: A tokenizer instance for text preprocessing + model_config: Configuration parameters for the text classification model + ragged_multilabel: Whether to use ragged multilabel classification + value_encoder: Optional ValueEncoder for encoding + raw string (or mixed) categorical values to integers. Build it + beforehand from DictEncoder or sklearn LabelEncoder instances and + pass it here. If None, categorical columns in X must already be + integer-encoded. + + Example: + >>> from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers + >>> from torchTextClassifiers.value_encoder import ValueEncoder, DictEncoder + >>> # Build one DictEncoder per categorical feature + >>> encoders = {str(i): DictEncoder({v: j for j, v in enumerate(sorted(set(X_categorical[:, i])))}) + ... for i in range(X_categorical.shape[1])} + >>> encoder = ValueEncoder(encoders) + >>> model_config = ModelConfig( + ... embedding_dim=10, + ... categorical_vocabulary_sizes=encoder.vocabulary_sizes, + ... categorical_embedding_dims=[10, 5], + ... num_classes=10, + ... 
) + >>> ttc = torchTextClassifiers( + ... tokenizer=tokenizer, + ... model_config=model_config, + ... value_encoder=encoder, + ... ) + """ + + self.model_config = model_config + self.tokenizer = tokenizer + self.ragged_multilabel = ragged_multilabel + self.value_encoder: ValueEncoder | None = value_encoder + + if hasattr(self.tokenizer, "trained"): + if not self.tokenizer.trained: + raise RuntimeError( + f"Tokenizer {type(self.tokenizer)} must be trained before initializing the classifier." + ) + + self.vocab_size = tokenizer.vocab_size + self.embedding_dim = model_config.embedding_dim + + if self.value_encoder is not None: + if (model_config.num_classes != self.value_encoder.num_classes) or ( + model_config.categorical_vocabulary_sizes != self.value_encoder.vocabulary_sizes + ): + logger.info( + "Overriding model_config num_classes and/or categorical_vocabulary_sizes with values from value_encoder." + ) + self.categorical_vocabulary_sizes = self.value_encoder.vocabulary_sizes + self.num_classes = self.value_encoder.num_classes + else: + self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes + if model_config.num_classes is None: + raise ValueError( + "num_classes must be specified in the model configuration if no value_encoder is provided." + ) + self.num_classes = model_config.num_classes + + self.enable_label_attention = model_config.n_heads_label_attention is not None + + if self.tokenizer.output_vectorized: + self.token_embedder = None + logger.info( + "Tokenizer outputs vectorized tokens; skipping TextEmbedder initialization." + ) + self.embedding_dim = self.tokenizer.output_dim + else: + token_embedder_config = TokenEmbedderConfig( + vocab_size=self.vocab_size, + embedding_dim=self.embedding_dim, + padding_idx=tokenizer.padding_idx, + attention_config=model_config.attention_config, + ) + sentence_embedder_config = SentenceEmbedderConfig( + label_attention_config=LabelAttentionConfig( + n_head=model_config.n_heads_label_attention, + num_classes=model_config.num_classes, + embedding_dim=self.embedding_dim, + ) + if self.enable_label_attention + else None, + aggregation_method=model_config.aggregation_method, + ) + self.token_embedder = TokenEmbedder( + token_embedder_config=token_embedder_config, + ) + self.sentence_embedder = SentenceEmbedder( + sentence_embedder_config=sentence_embedder_config + ) + + classif_head_input_dim = self.embedding_dim + if self.categorical_vocabulary_sizes: + self.categorical_var_net = CategoricalVariableNet( + categorical_vocabulary_sizes=self.categorical_vocabulary_sizes, + categorical_embedding_dims=model_config.categorical_embedding_dims, + text_embedding_dim=self.embedding_dim, + ) + + if self.categorical_var_net.forward_type != CategoricalForwardType.SUM_TO_TEXT: + classif_head_input_dim += self.categorical_var_net.output_dim + + else: + self.categorical_var_net = None + + self.classification_head = ClassificationHead( + input_dim=classif_head_input_dim, + num_classes=1 + if self.enable_label_attention + else self.num_classes, # output dim is 1 when using label attention, because embeddings are (num_classes, embedding_dim) + ) + + self.pytorch_model = TextClassificationModel( + token_embedder=self.token_embedder, + sentence_embedder=self.sentence_embedder, + categorical_variable_net=self.categorical_var_net, + classification_head=self.classification_head, + ) + + @classmethod + def from_model( + cls, + tokenizer: BaseTokenizer, + pytorch_model: nn.Module, + value_encoder: Optional[ValueEncoder] = None, + ragged_multilabel: 
Optional[bool] = False, + ): + """Initialize torchTextClassifiers from a pre-built PyTorch model. + + This method allows users to create a torchTextClassifiers instance using a pre-built PyTorch model that may not follow the standard architecture expected by the main constructor. The provided model should be compatible with the input format used in the predict method (i.e., it should accept tokenized text and categorical variables as input). + + Args: + tokenizer: A tokenizer instance for text preprocessing + pytorch_model: A pre-built PyTorch model to be used for predictions + value_encoder: Optional ValueEncoder for encoding raw string (or mixed) categorical values to integers. Build it beforehand from DictEncoder or sklearn LabelEncoder instances and pass it here. If None, categorical columns in X must already be integer-encoded. + + Returns: + An instance of torchTextClassifiers initialized with the provided model and tokenizer. + """ + instance = cls.__new__(cls) + instance.tokenizer = tokenizer + instance.pytorch_model = pytorch_model + instance.num_classes = pytorch_model.num_classes + instance.categorical_var_net = cast( + Optional[CategoricalVariableNet], pytorch_model.categorical_variable_net + ) + instance.value_encoder = value_encoder + instance.ragged_multilabel = ragged_multilabel + instance._custom_model = True + instance.enable_label_attention = False + instance.categorical_vocabulary_sizes = ( + instance.categorical_var_net.categorical_vocabulary_sizes + if instance.categorical_var_net is not None + else None + ) + return instance + + def train( + self, + X_train: np.ndarray, + y_train: np.ndarray, + training_config: TrainingConfig, + X_val: Optional[np.ndarray] = None, + y_val: Optional[np.ndarray] = None, + verbose: bool = False, + ) -> None: + """Train the classifier using PyTorch Lightning. + + This method handles the complete training process including: + - Data validation and preprocessing + - Dataset and DataLoader creation + - PyTorch Lightning trainer setup with callbacks + - Model training with early stopping + - Best model loading after training + + Note on Checkpoints: + After training, the best model checkpoint is automatically loaded. + This checkpoint contains the full training state (model weights, + optimizer, and scheduler state). Loading uses weights_only=False + as the checkpoint is self-generated and trusted. + + Args: + X_train: Training input data + y_train: Training labels + X_val: Validation input data + y_val: Validation labels + training_config: Configuration parameters for training + verbose: Whether to print training progress information + + + Example: + + >>> training_config = TrainingConfig( + ... lr=1e-3, + ... batch_size=4, + ... num_epochs=1, + ... ) + >>> ttc.train( + ... X_train=X, + ... y_train=Y, + ... X_val=X, + ... y_val=Y, + ... training_config=training_config, + ... ) + """ + + # Input validation + X_train, y_train = self._check_XY( + X_train, y_train, training_config.raw_categorical_inputs, training_config.raw_labels + ) + + if X_val is not None: + assert y_val is not None, "y_val must be provided if X_val is provided." + if y_val is not None: + assert X_val is not None, "X_val must be provided if y_val is provided." 
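# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of this patch): how the new
# multilevel path introduced here is meant to be exercised through
# `from_model`. The class, the sizes, and the `tokenizer` below are
# hypothetical; the only contract taken from this diff is that the wrapped
# model exposes `num_classes` as a list, exposes a `categorical_variable_net`
# attribute, and returns one logits tensor per classification level.
from torch import nn


class TinyMultilevelModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, padding_idx=0):
        super().__init__()
        self.num_classes = num_classes        # a list, e.g. [4, 12] -> one Accuracy metric per level
        self.categorical_variable_net = None  # attribute read by `from_model`
        self.embedding = nn.EmbeddingBag(
            vocab_size, embedding_dim, padding_idx=padding_idx, mode="mean"
        )
        self.heads = nn.ModuleList([nn.Linear(embedding_dim, n) for n in num_classes])

    def forward(self, input_ids, attention_mask=None, categorical_vars=None, **kwargs):
        sentence_embedding = self.embedding(input_ids)  # (batch_size, embedding_dim)
        return [head(sentence_embedding) for head in self.heads]  # one logits tensor per level


# Hypothetical usage, assuming a trained tokenizer from this package:
#   model = TinyMultilevelModel(tokenizer.vocab_size, 64, num_classes=[4, 12])
#   ttc = torchTextClassifiers.from_model(tokenizer=tokenizer, pytorch_model=model)
#   `predict` then returns stacked (N, top_k, n_levels) predictions, and training
#   labels are an integer array of shape (N, n_levels), one column per level;
#   the loss passed in TrainingConfig would have to accept the list of per-level
#   logits.
# ---------------------------------------------------------------------------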
+ + X_val: Optional[Dict[str, Any]] = None + if X_val is not None and y_val is not None: + X_val, y_val = self._check_XY(X_val, y_val) + + if ( + (X_train["categorical_variables"] is not None) + and (X_val is not None) + and (X_val["categorical_variables"] is not None) + ): + assert ( + X_train["categorical_variables"].ndim > 1 + and X_train["categorical_variables"].shape[1] + == X_val["categorical_variables"].shape[1] + or X_val["categorical_variables"].ndim == 1 + ), "X_train and X_val must have the same number of columns." + + if verbose: + logger.info("Starting training process...") + + if training_config.accelerator == "auto": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + device = torch.device(training_config.accelerator) + + self.device = device + + optimizer_params = {"lr": training_config.lr} + if training_config.optimizer_params is not None: + optimizer_params.update(training_config.optimizer_params) + + if training_config.loss is torch.nn.CrossEntropyLoss and self.ragged_multilabel: + logger.warning( + "āš ļø You have set ragged_multilabel to True but are using CrossEntropyLoss. We would recommend to use torch.nn.BCEWithLogitsLoss for multilabel classification tasks." + ) + + self.lightning_module = TextClassificationModule( + model=self.pytorch_model, + loss=training_config.loss, + optimizer=training_config.optimizer, + optimizer_params=optimizer_params, + scheduler=training_config.scheduler, + scheduler_params=training_config.scheduler_params + if training_config.scheduler_params + else {}, + scheduler_interval="epoch", + ) + + self.pytorch_model.to(self.device) + + if verbose: + logger.info(f"Running on: {device}") + + train_dataset = TextClassificationDataset( + texts=X_train["text"], + categorical_variables=X_train["categorical_variables"], # None if no cat vars + tokenizer=self.tokenizer, + labels=y_train.tolist(), + ragged_multilabel=self.ragged_multilabel, + ) + train_dataloader = train_dataset.create_dataloader( + batch_size=training_config.batch_size, + num_workers=training_config.num_workers, + shuffle=True, + **training_config.dataloader_params if training_config.dataloader_params else {}, + ) + + if X_val is not None and y_val is not None: + val_dataset = TextClassificationDataset( + texts=X_val["text"], + categorical_variables=X_val["categorical_variables"], # None if no cat vars + tokenizer=self.tokenizer, + labels=y_val, + ragged_multilabel=self.ragged_multilabel, + ) + val_dataloader = val_dataset.create_dataloader( + batch_size=training_config.batch_size, + num_workers=training_config.num_workers, + shuffle=False, + **training_config.dataloader_params if training_config.dataloader_params else {}, + ) + else: + val_dataloader = None + + # Setup trainer + callbacks = [ + ModelCheckpoint( + monitor="val_loss" if val_dataloader is not None else "train_loss", + save_top_k=1, + save_last=False, + mode="min", + ), + EarlyStopping( + monitor="val_loss" if val_dataloader is not None else "train_loss", + patience=training_config.patience_early_stopping, + mode="min", + ), + LearningRateMonitor(logging_interval="step"), + ] + + trainer_params = { + "accelerator": training_config.accelerator, + "callbacks": callbacks, + "max_epochs": training_config.num_epochs, + "num_sanity_val_steps": 2, + "strategy": "auto", + "log_every_n_steps": 1, + "enable_progress_bar": True, + } + + if training_config.trainer_params is not None: + trainer_params.update(training_config.trainer_params) + + trainer = pl.Trainer(**trainer_params) + + 
torch.cuda.empty_cache() + torch.set_float32_matmul_precision("medium") + + if verbose: + logger.info("Launching training...") + start = time.time() + + trainer.fit(self.lightning_module, train_dataloader, val_dataloader) + + if verbose: + end = time.time() + logger.info(f"Training completed in {end - start:.2f} seconds.") + + best_model_path = trainer.checkpoint_callback.best_model_path + self.checkpoint_path = best_model_path + + self.lightning_module = TextClassificationModule.load_from_checkpoint( + best_model_path, + model=self.pytorch_model, + loss=training_config.loss, + weights_only=False, # Required: checkpoint contains optimizer/scheduler state + ) + + self.pytorch_model = self.lightning_module.model.to(self.device) + + self.save_path = training_config.save_path + self.save(self.save_path) + + self.lightning_module.eval() + + def _check_XY( + self, X: np.ndarray, Y: np.ndarray, raw_categorical_inputs, raw_labels + ) -> Tuple[Dict[str, Any], np.ndarray]: + X_checked = self._check_X(X, raw_categorical_inputs) + Y_checked = self._check_Y(Y, raw_labels) + + if X_checked["text"].shape[0] != len(Y_checked): + raise ValueError("X_train and y_train must have the same number of observations.") + + return X_checked, Y_checked + + @staticmethod + def _check_text_col(X): + assert isinstance( + X, np.ndarray + ), "X must be a numpy array of shape (N,d), with the first column being the text and the rest being the categorical variables." + + try: + if X.ndim > 1: + text = X[:, 0].astype(str) + else: + text = X[:].astype(str) + except ValueError: + logger.error("The first column of X must be castable in string format.") + + return text + + def _check_categorical_variables( + self, X: np.ndarray, raw_categorical_inputs: bool + ) -> np.ndarray: + """Validate and encode categorical variables from X. + + If a ``value_encoder`` was provided at initialization, raw string + or mixed values are encoded to integers via that encoder. Otherwise the + categorical columns must already be integer-encodable. + + Args: + X: Full input array whose first column is text and whose remaining + columns are categorical variables. + + Returns: + Integer-encoded categorical array of shape (N, n_cat_features). + + Raises: + ValueError: If the number of categorical features does not match the + model configuration, if values exceed vocabulary bounds, or if + values cannot be cast to integers and no encoder was provided. + """ + assert self.categorical_var_net is not None + + num_cat_vars = X.shape[1] - 1 if X.ndim > 1 else 0 + + if num_cat_vars != self.categorical_var_net.num_categorical_features: + raise ValueError( + f"X must have the same number of categorical variables as the number of " + f"embedding layers in the categorical net: ({self.categorical_var_net.num_categorical_features})." + ) + + if raw_categorical_inputs: + if self.value_encoder is None: + raise ValueError( + "Raw categorical input encoding is enabled, but no value_encoder was provided. Please provide a ValueEncoder to encode raw categorical values to integers." + ) + categorical_variables = self.value_encoder.transform(X[:, 1:]).astype(int) + else: + categorical_variables = X[:, 1:].astype(int) + + for j in range(num_cat_vars): + max_cat_value = categorical_variables[:, j].max() + if max_cat_value >= self.categorical_var_net.categorical_vocabulary_sizes[j]: + raise ValueError( + f"Categorical variable at index {j} has value {max_cat_value} which exceeds " + f"the vocabulary size of {self.categorical_var_net.categorical_vocabulary_sizes[j]}." 
+ ) + + return categorical_variables + + def _check_X(self, X: np.ndarray, raw_categorical_inputs: bool) -> Dict[str, Any]: + text = self._check_text_col(X) + + categorical_variables = None + if self.categorical_var_net is not None: + categorical_variables = self._check_categorical_variables(X, raw_categorical_inputs) + + return {"text": text, "categorical_variables": categorical_variables} + + def _check_Y(self, Y, raw_labels: bool) -> np.ndarray: + if self.ragged_multilabel: + assert isinstance( + Y, list + ), "Y must be a list of lists for ragged multilabel classification." + for row in Y: + assert isinstance(row, list), "Each element of Y must be a list of labels." + + return Y + + else: + assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)." + assert ( + len(Y.shape) == 1 or len(Y.shape) == 2 + ), "Y must be a numpy array of shape (N,) or (N, num_labels)." + + if raw_labels: + if self.value_encoder is None: + raise ValueError( + "Raw label encoding is enabled, but no value_encoder was provided. Please provide a ValueEncoder to encode raw labels to integers." + ) + Y = self.value_encoder.transform_labels(Y) + Y = Y.astype(int) + + if isinstance(self.num_classes, list): + num_classes_arr = np.array(self.num_classes) + + print(Y, num_classes_arr) + if (Y.max(axis=0) >= num_classes_arr).any() or Y.min() < 0: + raise ValueError( + f"Y contains class labels outside the expected per-level ranges " + f"[0, {[nc - 1 for nc in self.num_classes]}]." + ) + elif Y.max() >= self.num_classes or Y.min() < 0: + raise ValueError( + f"Y contains class labels outside the range [0, {self.num_classes - 1}]." + ) + + return Y + + def predict( + self, + X_test: np.ndarray, + raw_categorical_inputs: bool = True, + top_k=1, + explain_with_label_attention: bool = False, + explain_with_captum=False, + ): + """ + Args: + X_test (np.ndarray): input data to predict on, shape (N,d) where the first column is text and the rest are categorical variables + top_k (int): for each sentence, return the top_k most likely predictions (default: 1) + explain_with_label_attention (bool): if enabled, use attention matrix labels x tokens to have an explanation of the prediction (default: False) + explain_with_captum (bool): launch gradient integration with Captum for explanation (default: False) + + Returns: A dictionary containing the following fields: + - predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes to the query. + - confidence (torch.Tensor, shape (len(text), top_k)): A tensor array containing the corresponding confidence scores. + - if explain is True: + - attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text. + """ + + explain = explain_with_label_attention or explain_with_captum + if explain: + return_offsets_mapping = True # to be passed to the tokenizer + return_word_ids = True + if self.pytorch_model.token_embedder is None: + raise RuntimeError( + "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs." + ) + else: + if explain_with_captum: + if not HAS_CAPTUM: + raise ImportError( + "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'." 
+ ) + lig = LayerIntegratedGradients( + self.pytorch_model, self.pytorch_model.token_embedder.embedding_layer + ) # initialize a Captum layer gradient integrator + if explain_with_label_attention: + if not self.enable_label_attention: + raise RuntimeError( + "Label attention explainability is enabled, but the model was not configured with label attention. Please enable label attention in the model configuration during initialization and retrain." + ) + else: + return_offsets_mapping = False + return_word_ids = False + + X_test = self._check_X(X_test, raw_categorical_inputs) + text = X_test["text"] + categorical_variables = X_test["categorical_variables"] + + self.pytorch_model.eval().cpu() + + tokenize_output = self.tokenizer.tokenize( + text.tolist(), + return_offsets_mapping=return_offsets_mapping, + return_word_ids=return_word_ids, + ) + + if not isinstance(tokenize_output, TokenizerOutput): + raise TypeError( + f"Expected TokenizerOutput, got {type(tokenize_output)} from tokenizer.tokenize method." + ) + + encoded_text = tokenize_output.input_ids # (batch_size, seq_len) + attention_mask = tokenize_output.attention_mask # (batch_size, seq_len) + + if categorical_variables is not None: + categorical_vars = torch.tensor( + categorical_variables, dtype=torch.float32 + ) # (batch_size, num_categorical_features) + else: + categorical_vars = torch.empty((encoded_text.shape[0], 0), dtype=torch.float32) + + model_output = self.pytorch_model( + encoded_text, + attention_mask, + categorical_vars, + return_label_attention_matrix=explain_with_label_attention, + ) # forward pass, contains the prediction scores (len(text), num_classes) + + # Multilevel custom model: returns a list of per-level logit tensors. + # Each level may have a different number of classes, so we process them + # separately then stack into (N, top_k, n_levels) before decoding. 
+ if isinstance(model_output, list): + int_preds_per_level: List[np.ndarray] = [] + conf_per_level: List[torch.Tensor] = [] + for level_logits in cast(List[torch.Tensor], model_output): + scores = level_logits.detach().cpu().softmax(dim=-1) + level_topk = torch.topk(scores, k=top_k, dim=-1) + int_preds_per_level.append(level_topk.indices.numpy()) # (N, top_k) + conf_per_level.append(level_topk.values) # (N, top_k) + + # (N, top_k, n_levels) + int_preds_stacked = np.stack(int_preds_per_level, axis=-1) + conf_stacked = torch.stack(conf_per_level, dim=-1) + + if self.value_encoder is not None: + predictions = self.value_encoder.inverse_transform_labels(int_preds_stacked) + else: + predictions = int_preds_stacked + + return { + "prediction": predictions, + "confidence": torch.round(conf_stacked, decimals=2), + } + + pred = ( + model_output["logits"] if explain_with_label_attention else model_output + ) # (batch_size, num_classes) + + label_attention_matrix = ( + model_output["label_attention_matrix"] if explain_with_label_attention else None + ) + + label_scores = pred.detach().cpu().softmax(dim=1) # convert to probabilities + + label_scores_topk = torch.topk(label_scores, k=top_k, dim=1) + + integer_predictions = label_scores_topk.indices # integer class indices (needed for captum) + if self.value_encoder is not None: + predictions = self.value_encoder.inverse_transform_labels(integer_predictions.numpy()) + else: + predictions = integer_predictions + + confidence = torch.round(label_scores_topk.values, decimals=2) # and their scores + + if explain: + if explain_with_captum: + # Captum explanations + captum_attributions = [] + for k in range(top_k): + attributions = lig.attribute( + (encoded_text, attention_mask, categorical_vars), + target=integer_predictions[:, k], + ) # (batch_size, seq_len) + attributions = attributions.sum(dim=-1) + captum_attributions.append(attributions.detach().cpu()) + + captum_attributions = torch.stack( + captum_attributions, dim=1 + ) # (batch_size, top_k, seq_len) + else: + captum_attributions = None + + return { + "prediction": predictions, + "confidence": confidence, + "captum_attributions": captum_attributions, + "label_attention_attributions": label_attention_matrix, + "offset_mapping": tokenize_output.offset_mapping, + "word_ids": tokenize_output.word_ids, + } + else: + return { + "prediction": predictions, + "confidence": confidence, + } + + def save(self, path: Union[str, Path]) -> None: + """Save the complete torchTextClassifiers instance to disk. + + This saves: + - Model configuration + - Tokenizer state + - PyTorch Lightning checkpoint (if trained) + - All other instance attributes + + Args: + path: Directory path where the model will be saved + + Example: + >>> ttc = torchTextClassifiers(tokenizer, model_config) + >>> ttc.train(X_train, y_train, training_config) + >>> ttc.save("my_model") + """ + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + is_custom_model = getattr(self, "_custom_model", False) + + # Custom models: save architecture as pickle + weights as state_dict. + # Standard models: save a full PyTorch Lightning checkpoint. 
+ checkpoint_path = None + if is_custom_model: + with open(path / "pytorch_model.pkl", "wb") as f: + pickle.dump(self.pytorch_model, f) + torch.save(self.pytorch_model.state_dict(), path / "model_weights.pt") + elif hasattr(self, "lightning_module"): + checkpoint_path = path / "model_checkpoint.ckpt" + trainer = pl.Trainer() + trainer.strategy.connect(self.lightning_module) + trainer.save_checkpoint(checkpoint_path) + + metadata: Dict[str, Any] = { + "is_custom_model": is_custom_model, + "loss": self.lightning_module.loss if hasattr(self, "lightning_module") else None, + "ragged_multilabel": self.ragged_multilabel, + "num_classes": self.num_classes, + "checkpoint_path": str(checkpoint_path) if checkpoint_path else None, + "device": str(self.device) if hasattr(self, "device") else None, + "has_value_encoder": self.value_encoder is not None, + } + + if not is_custom_model: + metadata.update( + { + "model_config": self.model_config.to_dict(), + "vocab_size": self.vocab_size, + "embedding_dim": self.embedding_dim, + "categorical_vocabulary_sizes": self.categorical_vocabulary_sizes, + } + ) + + with open(path / "metadata.pkl", "wb") as f: + pickle.dump(metadata, f) + + tokenizer_path = path / "tokenizer.pkl" + with open(tokenizer_path, "wb") as f: + pickle.dump(self.tokenizer, f) + + if self.value_encoder is not None: + with open(path / "value_encoder.pkl", "wb") as f: + pickle.dump(self.value_encoder, f) + + logger.info(f"Model saved successfully to {path}") + + @classmethod + def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassifiers": + """Load a torchTextClassifiers instance from disk. + + Args: + path: Directory path where the model was saved + device: Device to load the model on ('auto', 'cpu', 'cuda', etc.) + + Returns: + Loaded torchTextClassifiers instance + + Example: + >>> loaded_ttc = torchTextClassifiers.load("my_model") + >>> predictions = loaded_ttc.predict(X_test) + """ + path = Path(path) + + if not path.exists(): + raise FileNotFoundError(f"Model directory not found: {path}") + + with open(path / "metadata.pkl", "rb") as f: + metadata = pickle.load(f) + + with open(path / "tokenizer.pkl", "rb") as f: + tokenizer = pickle.load(f) + + resolved_device = ( + torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device == "auto" + else torch.device(device) + ) + + value_encoder = None + if metadata.get("has_value_encoder"): + encoder_path = path / "value_encoder.pkl" + if encoder_path.exists(): + with open(encoder_path, "rb") as f: + value_encoder = pickle.load(f) + + if metadata.get("is_custom_model", False): + with open(path / "pytorch_model.pkl", "rb") as f: + pytorch_model = pickle.load(f) + weights_path = path / "model_weights.pt" + if weights_path.exists(): + pytorch_model.load_state_dict(torch.load(weights_path, weights_only=True)) + logger.info(f"Model weights loaded from {weights_path}") + instance = cls.from_model( + tokenizer=tokenizer, + pytorch_model=pytorch_model, + value_encoder=value_encoder, + ragged_multilabel=metadata["ragged_multilabel"], + ) + instance.device = resolved_device + pytorch_model.to(resolved_device) + logger.info(f"Model loaded successfully from {path}") + return instance + + model_config = ModelConfig.from_dict(metadata["model_config"]) + + instance = cls( + tokenizer=tokenizer, + model_config=model_config, + ragged_multilabel=metadata["ragged_multilabel"], + value_encoder=value_encoder, + ) + + instance.device = resolved_device + + if metadata.get("checkpoint_path"): + checkpoint_path = path / 
"model_checkpoint.ckpt" + if checkpoint_path.exists(): + loss = metadata.get("loss") or torch.nn.CrossEntropyLoss() + instance.lightning_module = TextClassificationModule.load_from_checkpoint( + str(checkpoint_path), + model=instance.pytorch_model, + loss=loss, + weights_only=False, + ) + instance.pytorch_model = instance.lightning_module.model.to(resolved_device) + instance.checkpoint_path = str(checkpoint_path) + logger.info(f"Model checkpoint loaded from {checkpoint_path}") + else: + logger.warning(f"Checkpoint file not found at {checkpoint_path}") + + logger.info(f"Model loaded successfully from {path}") + return instance + + def __repr__(self): + model_type = ( + self.lightning_module.__repr__() + if hasattr(self, "lightning_module") + else self.pytorch_model.__repr__() + ) + + tokenizer_info = self.tokenizer.__repr__() + + cat_forward_type = ( + self.categorical_var_net.forward_type.name + if self.categorical_var_net is not None + else "None" + ) + + lines = [ + "torchTextClassifiers(", + f" tokenizer = {tokenizer_info},", + f" model = {model_type},", + f" categorical_forward_type = {cat_forward_type},", + f" num_classes = {self.model_config.num_classes},", + f" embedding_dim = {self.embedding_dim},", + ")", + ] + return "\n".join(lines) diff --git a/torchTextClassifiers/value_encoder/value_encoder.py b/torchTextClassifiers/value_encoder/value_encoder.py index 8215093..7ca946b 100644 --- a/torchTextClassifiers/value_encoder/value_encoder.py +++ b/torchTextClassifiers/value_encoder/value_encoder.py @@ -56,14 +56,21 @@ class ValueEncoder: def __init__( self, - label_encoder: DictEncoder | LabelEncoder, + label_encoder: DictEncoder | LabelEncoder | list[DictEncoder | LabelEncoder], categorical_encoders: Optional[dict[str, DictEncoder | LabelEncoder]] = None, ): self.categorical_encoders = categorical_encoders - if not isinstance(label_encoder, (DictEncoder, LabelEncoder)): + if isinstance(label_encoder, list): + for enc in label_encoder: + if not isinstance(enc, (DictEncoder, LabelEncoder)): + raise TypeError( + "Each element of label_encoder list must be a DictEncoder or LabelEncoder, " + f"got {type(enc)}" + ) + elif not isinstance(label_encoder, (DictEncoder, LabelEncoder)): raise TypeError( - "label_encoder must be a DictEncoder or LabelEncoder instance, " + "label_encoder must be a DictEncoder, LabelEncoder, or list thereof, " f"got {type(label_encoder)}" ) self.label_encoder = label_encoder @@ -88,10 +95,21 @@ def vocabulary_sizes(self) -> list[int]: @property def num_classes(self) -> int: """Number of unique classes in the label encoder, if provided.""" - if isinstance(self.label_encoder, DictEncoder): - return len(self.label_encoder.mapping) - elif hasattr(self.label_encoder, "classes_"): - return len(self.label_encoder.classes_) + + def _get_num_classes(enc): + if isinstance(enc, DictEncoder): + return len(enc.mapping) + elif hasattr(enc, "classes_"): + return len(enc.classes_) + else: + raise TypeError(f"Unsupported encoder type: {type(enc)}") + + if isinstance(self.label_encoder, DictEncoder) or isinstance( + self.label_encoder, LabelEncoder + ): + return _get_num_classes(self.label_encoder) + elif isinstance(self.label_encoder, list): + return [_get_num_classes(enc) for enc in self.label_encoder] else: raise TypeError(f"Unsupported label encoder type: {type(self.label_encoder)}") @@ -137,38 +155,85 @@ def transform_labels(self, y_labels: np.ndarray) -> np.ndarray: Values are converted to strings before lookup. Unknown values raise a ValueError. 
Args: - y_labels: Array of shape (N,) with label values. + y_labels: Array of shape (N,) for a single encoder, or (N, n_levels) for a list + of encoders. Returns: - Integer-encoded array of shape (N,), dtype int64. + Integer-encoded array of shape (N,) or (N, n_levels), dtype int64. Raises: ValueError: If any label value was not seen during fitting. """ - col = y_labels.astype(str) - encoded = self.label_encoder.transform(col) - try: - return encoded.astype(np.int64) - except (TypeError, ValueError): - unknown = [v for v, e in zip(col.tolist(), encoded.tolist()) if e is None] - raise ValueError( - f"Unknown values in label encoder: {unknown}. " - "These values were not seen during fitting." - ) + def _encode_col(enc, col, level_name="label encoder"): + encoded = enc.transform(col) + try: + return encoded.astype(np.int64) + except (TypeError, ValueError): + unknown = [v for v, e in zip(col.tolist(), encoded.tolist()) if e is None] + raise ValueError( + f"Unknown values in {level_name}: {unknown}. " + "These values were not seen during fitting." + ) + + if isinstance(self.label_encoder, list): + if y_labels.ndim == 1: + y_labels = y_labels.reshape(-1, 1) + result = np.empty(y_labels.shape, dtype=np.int64) + for idx, enc in enumerate(self.label_encoder): + result[:, idx] = _encode_col( + enc, y_labels[:, idx].astype(str), f"label encoder at level {idx}" + ) + return result + + return _encode_col(self.label_encoder, y_labels.astype(str)) def inverse_transform_labels(self, y_encoded: np.ndarray) -> np.ndarray: """Decode integer-encoded labels back to original values. Args: - y_encoded: Array of shape (N,) with integer-encoded labels. + y_encoded: Supported shapes: + - Single encoder: (N,) or (N, top_k) + - List of encoders: (N, n_levels), or (N, top_k, n_levels) Returns: - Array of shape (N,) with original label values. + Decoded array with the same shape as ``y_encoded``. Raises: ValueError: If any encoded label value was not seen during fitting. 
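+
+        Example:
+            >>> # Illustrative only: assumes two per-level encoders fitted elsewhere
+            >>> encoder = ValueEncoder(label_encoder=[level1_encoder, level2_encoder])
+            >>> encoder.inverse_transform_labels(np.array([[0, 2], [1, 0]])).shape
+            (2, 2)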
""" + def _decode_col(enc, col: np.ndarray) -> np.ndarray: + """Decode a flat 1-D array using a single encoder.""" + if isinstance(enc, DictEncoder): + return np.vectorize(enc.inverse_mapping.get)(col) + elif hasattr(enc, "inverse_transform"): + return enc.inverse_transform(col) + else: + raise TypeError(f"Unsupported label encoder type: {type(enc)}") + + if isinstance(self.label_encoder, list): + if y_encoded.ndim == 1: + y_encoded = y_encoded.reshape(-1, 1) + + if y_encoded.ndim == 2: + # (N, n_levels) + result = np.empty(y_encoded.shape, dtype=object) + for idx, enc in enumerate(self.label_encoder): + result[:, idx] = _decode_col(enc, y_encoded[:, idx]) + return result + + if y_encoded.ndim == 3: + # (N, top_k, n_levels) + n, top_k = y_encoded.shape[0], y_encoded.shape[1] + result = np.empty(y_encoded.shape, dtype=object) + for idx, enc in enumerate(self.label_encoder): + flat = y_encoded[:, :, idx].ravel() + result[:, :, idx] = _decode_col(enc, flat).reshape(n, top_k) + return result + + raise ValueError( + f"Expected 1-D, 2-D, or 3-D array for list encoder, got {y_encoded.ndim}-D" + ) + if isinstance(self.label_encoder, DictEncoder): - inverse_mapping = self.label_encoder.inverse_mapping - return np.vectorize(inverse_mapping.get)(y_encoded) + return np.vectorize(self.label_encoder.inverse_mapping.get)(y_encoded) elif hasattr(self.label_encoder, "inverse_transform"): shape = y_encoded.shape result = self.label_encoder.inverse_transform(y_encoded.ravel()) From e8c4629d1cfb29d4852d6a76ec082983d3b6d69e Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Tue, 28 Apr 2026 15:50:11 +0000 Subject: [PATCH 4/6] fix: more general typing for model --- torchTextClassifiers/model/lightning.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchTextClassifiers/model/lightning.py b/torchTextClassifiers/model/lightning.py index 886b347..1b9ed94 100644 --- a/torchTextClassifiers/model/lightning.py +++ b/torchTextClassifiers/model/lightning.py @@ -1,9 +1,8 @@ import pytorch_lightning as pl import torch +from torch import nn from torchmetrics import Accuracy -from .model import TextClassificationModel - # ============================================================================ # PyTorch Lightning Module # ============================================================================ @@ -14,7 +13,7 @@ class TextClassificationModule(pl.LightningModule): def __init__( self, - model: TextClassificationModel, + model: nn.Module, loss, optimizer, optimizer_params, From ea00b8b247301d2d612923d495ec58833e084f66 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Tue, 28 Apr 2026 15:54:46 +0000 Subject: [PATCH 5/6] feat!(contrib): add contrib folder for custom complex archi - common in OS packages - add also an example script for multilevel classif --- examples/multilevel_example.py | 110 +----------- torchTextClassifiers/contrib/__init__.py | 13 ++ torchTextClassifiers/contrib/multilevel.py | 191 +++++++++++++++++++++ 3 files changed, 210 insertions(+), 104 deletions(-) create mode 100644 torchTextClassifiers/contrib/__init__.py create mode 100644 torchTextClassifiers/contrib/multilevel.py diff --git a/examples/multilevel_example.py b/examples/multilevel_example.py index e04b943..f6b347f 100644 --- a/examples/multilevel_example.py +++ b/examples/multilevel_example.py @@ -1,13 +1,15 @@ -from typing import Optional +from typing import cast import numpy as np import pandas as pd import torch from sklearn.preprocessing import LabelEncoder -from torch import nn -import 
torchTextClassifiers from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers +from torchTextClassifiers.contrib import ( + MultiLevelCrossEntropyLoss, + MultiLevelTextClassificationModel, +) from torchTextClassifiers.dataset import TextClassificationDataset from torchTextClassifiers.model import TextClassificationModule from torchTextClassifiers.model.components import ( @@ -167,106 +169,6 @@ all_classification_heads.append(classification_head) -class MultiLevelTextClassificationModel(nn.Module): - def __init__( - self, - token_embedder: TokenEmbedder, - sentence_embedders: list[SentenceEmbedder], - classification_heads: list[ClassificationHead], - categorical_variable_net: CategoricalVariableNet, - ): - super().__init__() - self.token_embedder = token_embedder - self.sentence_embedders = sentence_embedders - self.classification_heads = classification_heads - self.categorical_variable_net = categorical_variable_net - self.num_classes: list[int] = [ - self.sentence_embedders[i].label_attention_config.num_classes - if self.sentence_embedders[i].label_attention_config is not None - else self.classification_heads[i].num_classes - for i in range(len(self.sentence_embedders)) - ] - - def forward(self, input_ids, attention_mask, categorical_vars=None, **kwargs): - token_embed_output = self.token_embedder(input_ids, attention_mask) - x_token = token_embed_output["token_embeddings"] - x_cat = self.categorical_variable_net(categorical_vars) # (bs, cat_emb_dim) - - print(f"Token embeddings shape: {x_token.shape}") - print( - f"x_cat shape: {x_cat.shape}" - ) # Debugging line to check shape of categorical variable embeddings - outputs = [] - for sentence_embedder, classification_head in zip( - self.sentence_embedders, self.classification_heads - ): - if sentence_embedder.label_attention_config is not None: - num_classes = sentence_embedder.label_attention_config.num_classes - x_cat_level = x_cat.unsqueeze(1).expand(-1, num_classes, -1) - else: - x_cat_level = x_cat - - print( - f"x_cat_level shape: {x_cat_level.shape}" - ) # Debugging line to check shape of categorical variable embeddings after expansion - sentence_embedding = sentence_embedder( - token_embeddings=x_token, attention_mask=attention_mask - )[ - "sentence_embedding" - ] # (bs, embedding_dim) or (bs, num_classes, embedding_dim) if label attention - - print( - f"Sentence embedding shape: {sentence_embedding.shape}" - ) # Debugging line to check shape of sentence embeddings - if ( - self.categorical_variable_net.forward_type - == CategoricalForwardType.AVERAGE_AND_CONCAT - or self.categorical_variable_net.forward_type - == CategoricalForwardType.CONCATENATE_ALL - ): - x_combined = torch.cat((sentence_embedding, x_cat_level), dim=-1) - else: - assert ( - self.categorical_variable_net.forward_type == CategoricalForwardType.SUM_TO_TEXT - ) - - x_combined = sentence_embedding + x_cat_level - - print( - f"x_combined shape: {x_combined.shape}" - ) # Debugging line to check shape of combined features before classification head - output = classification_head(x_combined).squeeze(-1) - outputs.append(output) - print( - f"Output shape for current level: {output.shape}" - ) # Debugging line to check shape of output logits for current level - - return outputs - - -class MultiLevelCrossEntropyLoss(nn.Module): - def __init__(self, num_classes: Optional[list[int]] = None): - super().__init__() - self.num_classes = num_classes - self.loss_fn = nn.CrossEntropyLoss() - - def forward(self, outputs, labels): - total_loss = 0 - for 
idx, output in enumerate(outputs): - label = labels[:, idx] # (batch_size,) - if self.num_classes is not None: - total_loss += self.loss_fn(output.squeeze(), label) * self.num_classes[idx] - else: - total_loss += self.loss_fn(output.squeeze(), label) - - if self.num_classes is not None: - total_weight = sum(self.num_classes) - else: - total_weight = len(outputs) - - return total_loss / total_weight # average loss across levels - - model = MultiLevelTextClassificationModel( token_embedder=token_embedder, sentence_embedders=all_sentence_embedders, @@ -300,7 +202,7 @@ def forward(self, outputs, labels): batch_size=6, lr=1e-3, raw_categorical_inputs=True, - loss=MultiLevelCrossEntropyLoss(num_classes=value_encoder.num_classes), + loss=MultiLevelCrossEntropyLoss(num_classes=cast(list[int], value_encoder.num_classes)), ) ttc.train( diff --git a/torchTextClassifiers/contrib/__init__.py b/torchTextClassifiers/contrib/__init__.py new file mode 100644 index 0000000..1f4bef9 --- /dev/null +++ b/torchTextClassifiers/contrib/__init__.py @@ -0,0 +1,13 @@ +"""contrib: example custom architectures for torchTextClassifiers. + +These classes are reference implementations that demonstrate how to build custom +PyTorch models compatible with the ``torchTextClassifiers.from_model`` entry point. +They are not part of the core API and may evolve independently. +""" + +from .multilevel import MultiLevelCrossEntropyLoss, MultiLevelTextClassificationModel + +__all__ = [ + "MultiLevelTextClassificationModel", + "MultiLevelCrossEntropyLoss", +] diff --git a/torchTextClassifiers/contrib/multilevel.py b/torchTextClassifiers/contrib/multilevel.py new file mode 100644 index 0000000..7a39755 --- /dev/null +++ b/torchTextClassifiers/contrib/multilevel.py @@ -0,0 +1,191 @@ +"""Multi-level (multi-task) classification architecture. + +Provides a custom model and a matching loss function for tasks that require +predicting several classification targets simultaneously from the same text +input — for example, hierarchical category codes (level 1, level 2, level 3) +predicted in a single forward pass. + +These classes are compatible with ``torchTextClassifiers.from_model`` and can +serve as a starting point for your own multi-task architectures. + +Example usage:: + + from torchTextClassifiers import torchTextClassifiers + from torchTextClassifiers.contrib import ( + MultiLevelTextClassificationModel, + MultiLevelCrossEntropyLoss, + ) + from torchTextClassifiers.model import TextClassificationModule + + model = MultiLevelTextClassificationModel( + token_embedder=token_embedder, + sentence_embedders=sentence_embedders, # one per level + classification_heads=classification_heads, # one per level + categorical_variable_net=categorical_var_net, + ) + + # Train with PyTorch Lightning directly + module = TextClassificationModule( + model=model, + loss=MultiLevelCrossEntropyLoss(), + optimizer=torch.optim.Adam, + optimizer_params={"lr": 1e-3}, + ) + + # Or wrap with the high-level API for predict() / save() / load() + classifier = torchTextClassifiers.from_model( + tokenizer=tokenizer, + pytorch_model=model, + ) +""" + +from typing import Optional + +import torch +from torch import nn + +from torchTextClassifiers.model.components import ( + CategoricalForwardType, + CategoricalVariableNet, + ClassificationHead, + SentenceEmbedder, + TokenEmbedder, +) + + +class MultiLevelTextClassificationModel(nn.Module): + """Multi-task text classifier that predicts several classes in one forward pass. 
+ + Each classification level has its own ``SentenceEmbedder`` and + ``ClassificationHead`` but they all share the same ``TokenEmbedder``, so + the token-level representations are computed only once. + + Attributes: + num_classes: List of class counts, one entry per level. Required by + ``torchTextClassifiers.from_model``. + categorical_variable_net: The categorical embedding module (may be + ``None``). Required by ``torchTextClassifiers.from_model``. + + Args: + token_embedder: Shared token embedding module. + sentence_embedders: One ``SentenceEmbedder`` per classification level. + classification_heads: One ``ClassificationHead`` per level. + categorical_variable_net: Categorical feature embedding module. + """ + + def __init__( + self, + token_embedder: TokenEmbedder, + sentence_embedders: list[SentenceEmbedder], + classification_heads: list[ClassificationHead], + categorical_variable_net: CategoricalVariableNet, + ): + super().__init__() + self.token_embedder = token_embedder + self.sentence_embedders = nn.ModuleList(sentence_embedders) + self.classification_heads = nn.ModuleList(classification_heads) + self.categorical_variable_net = categorical_variable_net + self.num_classes: list[int] = [ + se.label_attention_config.num_classes + if se.label_attention_config is not None + else ch.num_classes + for se, ch in zip(sentence_embedders, classification_heads) + ] + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + categorical_vars: Optional[torch.Tensor] = None, + **kwargs, + ) -> list[torch.Tensor]: + """Run a forward pass and return one logit tensor per level. + + Args: + input_ids: Tokenised text, shape ``(batch, seq_len)``. + attention_mask: Padding mask, shape ``(batch, seq_len)``. + categorical_vars: Integer-encoded categorical features, + shape ``(batch, n_cats)``. May be ``None`` if no categorical + features are used. + + Returns: + List of raw logit tensors, one per classification level. + Each tensor has shape ``(batch, num_classes_at_level)``. + """ + token_embed_output = self.token_embedder(input_ids, attention_mask) + x_token = token_embed_output["token_embeddings"] + x_cat = self.categorical_variable_net(categorical_vars) + + outputs = [] + for sentence_embedder, classification_head in zip( + self.sentence_embedders, self.classification_heads + ): + if sentence_embedder.label_attention_config is not None: + num_cls = sentence_embedder.label_attention_config.num_classes + x_cat_level = x_cat.unsqueeze(1).expand(-1, num_cls, -1) + else: + x_cat_level = x_cat + + sentence_embedding = sentence_embedder( + token_embeddings=x_token, attention_mask=attention_mask + )["sentence_embedding"] + + fwd = self.categorical_variable_net.forward_type + if fwd in ( + CategoricalForwardType.AVERAGE_AND_CONCAT, + CategoricalForwardType.CONCATENATE_ALL, + ): + x_combined = torch.cat((sentence_embedding, x_cat_level), dim=-1) + else: + assert fwd == CategoricalForwardType.SUM_TO_TEXT + x_combined = sentence_embedding + x_cat_level + + outputs.append(classification_head(x_combined).squeeze(-1)) + + return outputs + + +class MultiLevelCrossEntropyLoss(nn.Module): + """Weighted cross-entropy loss across multiple classification levels. + + Averages the per-level cross-entropy losses, optionally weighting each + level by its number of classes so that finer-grained levels contribute + more to the total gradient. + + Args: + num_classes: If provided, level ``i`` is weighted by + ``num_classes[i] / sum(num_classes)``. If ``None``, all levels + are weighted equally. 
+ + Example:: + + loss_fn = MultiLevelCrossEntropyLoss(num_classes=[5, 20, 100]) + # or unweighted: + loss_fn = MultiLevelCrossEntropyLoss() + """ + + def __init__(self, num_classes: Optional[list[int]] = None): + super().__init__() + self.num_classes = num_classes + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, outputs: list[torch.Tensor], labels: torch.Tensor) -> torch.Tensor: + """Compute the weighted average loss. + + Args: + outputs: List of logit tensors ``(batch, num_classes_i)`` returned + by ``MultiLevelTextClassificationModel``. + labels: Integer label tensor of shape ``(batch, n_levels)``. + Column ``i`` contains the ground-truth label for level ``i``. + + Returns: + Scalar loss tensor. + """ + total_loss = torch.tensor(0.0, device=outputs[0].device) + for idx, output in enumerate(outputs): + label = labels[:, idx] + weight = self.num_classes[idx] if self.num_classes is not None else 1 + total_loss = total_loss + self.loss_fn(output.squeeze(), label) * weight + + total_weight = sum(self.num_classes) if self.num_classes is not None else len(outputs) + return total_loss / total_weight From 2300bfc491610185f7c6dc6ff6b1448e9b33b470 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Tue, 28 Apr 2026 15:55:47 +0000 Subject: [PATCH 6/6] docs: update with custom model availability + multilevel --- README.md | 3 +- docs/source/architecture/overview.md | 92 ++++++++ docs/source/tutorials/custom_model.md | 215 +++++++++++++++++++ docs/source/tutorials/index.md | 19 ++ torchTextClassifiers/torchTextClassifiers.py | 52 ++++- 5 files changed, 374 insertions(+), 7 deletions(-) create mode 100644 docs/source/tutorials/custom_model.md diff --git a/README.md b/README.md index b8d4767..571e666 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ A unified, extensible framework for text classification with categorical variabl - **Unified yet highly customizable**: - Use any tokenizer from HuggingFace or the original fastText's ngram tokenizer. - Text embedding is split into two composable stages: **`TokenEmbedder`** (token → per-token vectors, with optional self-attention) and **`SentenceEmbedder`** (aggregation: mean / first / last / label attention). Combine them with `CategoricalVariableNet` and `ClassificationHead` — all are `torch.nn.Module`. - - The `TextClassificationModel` class assembles these components and can be extended for custom behavior. + - **Two architecture paths**: use `ModelConfig` + the `torchTextClassifiers` constructor for the standard `TextClassificationModel` (zero boilerplate), or build any `nn.Module` you like and pass it to `torchTextClassifiers.from_model()` for full control. The `contrib` sub-package ships ready-made custom architectures (e.g. `MultiLevelTextClassificationModel` for multi-task classification) as reference implementations. - **Multiclass / multilabel classification support**: Support for both multiclass (only one label is true) and multi-label (several labels can be true) classification tasks. 
- **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging - **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code: @@ -56,6 +56,7 @@ See the [examples/](examples/) directory for: - Mixed features (text + categorical) - Advanced training configurations - Prediction and explainability +- [Multi-level classification](examples/multilevel_example.py) — custom architecture via `from_model` and `contrib` ## šŸ“„ License diff --git a/docs/source/architecture/overview.md b/docs/source/architecture/overview.md index 737fa28..e0b5e84 100644 --- a/docs/source/architecture/overview.md +++ b/docs/source/architecture/overview.md @@ -599,6 +599,98 @@ predictions = classifier.predict(new_texts) - Don't need custom architecture - Want simplicity over control +## Two Architecture Paths + +torchTextClassifiers offers two ways to build a classifier, covering different +levels of customisation: + +### Path 1 — Standard architecture (ModelConfig) + +The `torchTextClassifiers` constructor accepts a `ModelConfig` and builds a +`TextClassificationModel` for you automatically. This covers the vast majority +of use cases: single-task binary, multi-class, or multi-label classification, +with or without categorical variables, with or without self-attention and label +attention. + +```python +from torchTextClassifiers import torchTextClassifiers, ModelConfig + +classifier = torchTextClassifiers( + tokenizer=tokenizer, + model_config=ModelConfig(embedding_dim=128, num_classes=5), +) +classifier.train(texts, labels, training_config) +predictions = classifier.predict(new_texts) +``` + +You never instantiate `TextClassificationModel` directly; `ModelConfig` is the +only knob you need. + +### Path 2 — Custom architecture (from_model) + +When `TextClassificationModel` cannot express what you need — multiple +classification heads, shared encoders across tasks, or any other topology — +build your own `nn.Module` and wrap it with `torchTextClassifiers.from_model`. +The wrapper then provides the same `predict` / `save` / `load` interface around +your model. + +```python +import torch.nn as nn +from torchTextClassifiers import torchTextClassifiers + +class MyModel(nn.Module): + num_classes = 3 + categorical_variable_net = None # or a CategoricalVariableNet instance + + def forward(self, input_ids, attention_mask, categorical_vars=None, **kwargs): + ... 
+ return logits # (batch, num_classes) — raw logits, not softmaxed + +classifier = torchTextClassifiers.from_model( + tokenizer=tokenizer, + pytorch_model=MyModel(), +) +``` + +**Required interface for custom models:** + +| Requirement | Details | +|---|---| +| `forward(input_ids, attention_mask, categorical_vars=None, **kwargs)` | Exact positional names; extra kwargs are ignored | +| Returns raw logits | `torch.Tensor` of shape `(batch, num_classes)`, or `list[torch.Tensor]` for multi-task | +| `num_classes` attribute | `int` for single-task; `list[int]` for multi-task | +| `categorical_variable_net` attribute | A `CategoricalVariableNet` instance, or `None` | + +### contrib — reference custom architectures + +The `torchTextClassifiers.contrib` sub-package ships example architectures that +follow the `from_model` interface and can be used directly or as starting points: + +| Class | Purpose | +|---|---| +| `MultiLevelTextClassificationModel` | Multi-task classifier: one shared `TokenEmbedder`, one `SentenceEmbedder` + `ClassificationHead` per task | +| `MultiLevelCrossEntropyLoss` | Weighted cross-entropy averaged across tasks | + +```python +from torchTextClassifiers.contrib import ( + MultiLevelTextClassificationModel, + MultiLevelCrossEntropyLoss, +) + +model = MultiLevelTextClassificationModel( + token_embedder=token_embedder, + sentence_embedders=[se_level1, se_level2, se_level3], + classification_heads=[head1, head2, head3], + categorical_variable_net=cat_net, +) +classifier = torchTextClassifiers.from_model(tokenizer=tokenizer, pytorch_model=model) +``` + +See [examples/multilevel_example.py](https://github.com/InseeFrLab/torchTextClassifiers/blob/main/examples/multilevel_example.py) +for a complete working script. + +--- + ## For Advanced Users ### Direct PyTorch Usage diff --git a/docs/source/tutorials/custom_model.md b/docs/source/tutorials/custom_model.md new file mode 100644 index 0000000..c8a505a --- /dev/null +++ b/docs/source/tutorials/custom_model.md @@ -0,0 +1,215 @@ +# Custom Architectures with from_model + +**Difficulty:** Advanced | **Time:** 30 minutes + +## When to use this + +The standard `torchTextClassifiers` constructor + `ModelConfig` covers most +single-task classification needs. Use `from_model` when you need something the +standard architecture cannot express: + +- **Multiple classification heads** (multi-task / hierarchical labels) +- **Shared encoders** across several outputs +- **Custom combination logic** between text and categorical embeddings +- **Any other topology** that does not fit a single linear pipeline + +--- + +## Required interface + +Your custom model must satisfy three contracts so that the wrapper's `predict`, +`save`, and `load` methods work correctly. + +### 1. `forward` signature + +```python +def forward( + self, + input_ids: torch.Tensor, # (batch, seq_len) — Long + attention_mask: torch.Tensor, # (batch, seq_len) — int + categorical_vars: torch.Tensor, # (batch, n_cats) — Long, may be None + **kwargs, # ignored by the wrapper +) -> torch.Tensor | list[torch.Tensor]: + ... +``` + +- The argument **names must match exactly** — the wrapper calls the model with + keyword arguments from the dataloader collate function. +- The return value must be **raw logits** (not softmaxed). + - Single task → `torch.Tensor` of shape `(batch, num_classes)` + - Multi-task → `list[torch.Tensor]`, one tensor per task + +### 2. 
`num_classes` attribute + +```python +model.num_classes # int (single task) +model.num_classes # list[int] (multi-task — one entry per task head) +``` + +### 3. `categorical_variable_net` attribute + +```python +model.categorical_variable_net # CategoricalVariableNet | None +``` + +Set this to `None` if your model does not use categorical features. When it is +not `None` the wrapper reads +`categorical_variable_net.categorical_vocabulary_sizes` to configure data +encoding. + +--- + +## Minimal example — single-task custom model + +```python +import torch +import torch.nn as nn +from torchTextClassifiers import torchTextClassifiers +from torchTextClassifiers.model.components import TokenEmbedder, TokenEmbedderConfig +from torchTextClassifiers.tokenizers import WordPieceTokenizer + +class MyClassifier(nn.Module): + def __init__(self, vocab_size: int, num_classes: int): + super().__init__() + self.token_embedder = TokenEmbedder(TokenEmbedderConfig( + vocab_size=vocab_size, embedding_dim=64, padding_idx=0, + )) + self.pool = lambda x, mask: (x * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True) + self.head = nn.Linear(64, num_classes) + + # Required attributes + self.num_classes = num_classes + self.categorical_variable_net = None # no categorical features + + def forward(self, input_ids, attention_mask, categorical_vars=None, **kwargs): + out = self.token_embedder(input_ids, attention_mask) + sentence = self.pool(out["token_embeddings"], attention_mask.float()) + return self.head(sentence) # (batch, num_classes) — raw logits + +tokenizer = WordPieceTokenizer(vocab_size=5000) +tokenizer.train(texts) + +model = MyClassifier(vocab_size=tokenizer.vocab_size, num_classes=3) + +classifier = torchTextClassifiers.from_model( + tokenizer=tokenizer, + pytorch_model=model, +) +classifier.train(texts, labels, training_config) +predictions = classifier.predict(new_texts) +``` + +--- + +## Multi-task example — contrib architecture + +For multi-task classification the `contrib` sub-package provides ready-made +classes that follow the interface above. + +```python +from torchTextClassifiers import torchTextClassifiers, TrainingConfig +from torchTextClassifiers.contrib import ( + MultiLevelTextClassificationModel, + MultiLevelCrossEntropyLoss, +) +from torchTextClassifiers.model.components import ( + CategoricalVariableNet, + ClassificationHead, + LabelAttentionConfig, + SentenceEmbedder, SentenceEmbedderConfig, + TokenEmbedder, TokenEmbedderConfig, +) +from torchTextClassifiers.value_encoder import ValueEncoder + +# Assume tokenizer, value_encoder, and model_config are already built. +# value_encoder.num_classes is a list[int] — one count per task level. 
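+#
+# For illustration only, one way those objects could be built (the data variables
+# `texts`, `y_level1`, `y_level2` are hypothetical):
+#   tokenizer = WordPieceTokenizer(vocab_size=5000); tokenizer.train(texts)
+#   value_encoder = ValueEncoder(
+#       label_encoder=[LabelEncoder().fit(y_level1), LabelEncoder().fit(y_level2)],
+#   )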
+
+token_embedder = TokenEmbedder(TokenEmbedderConfig(
+    vocab_size=tokenizer.vocab_size,
+    embedding_dim=64,
+    padding_idx=tokenizer.padding_idx,
+))
+cat_net = CategoricalVariableNet(
+    categorical_vocabulary_sizes=value_encoder.vocabulary_sizes,
+    categorical_embedding_dims=8,
+    text_embedding_dim=64,
+)
+
+sentence_embedders = []
+classification_heads = []
+for n_cls in value_encoder.num_classes:
+    sentence_embedders.append(SentenceEmbedder(SentenceEmbedderConfig(
+        aggregation_method=None,
+        label_attention_config=LabelAttentionConfig(n_head=2, num_classes=n_cls, embedding_dim=64),
+    )))
+    classification_heads.append(ClassificationHead(input_dim=64 + cat_net.output_dim, num_classes=1))
+
+model = MultiLevelTextClassificationModel(
+    token_embedder=token_embedder,
+    sentence_embedders=sentence_embedders,
+    classification_heads=classification_heads,
+    categorical_variable_net=cat_net,
+)
+
+classifier = torchTextClassifiers.from_model(
+    tokenizer=tokenizer,
+    pytorch_model=model,
+    value_encoder=value_encoder,
+)
+
+training_config = TrainingConfig(
+    num_epochs=10,
+    batch_size=32,
+    lr=1e-3,
+    raw_categorical_inputs=True,
+    loss=MultiLevelCrossEntropyLoss(num_classes=list(value_encoder.num_classes)),
+)
+classifier.train(X_train, y_train, training_config)
+predictions = classifier.predict(X_test)
+```
+
+`predictions` is a dict with `prediction` and `confidence` entries; for a multi-level
+model, `prediction` has shape `(n_samples, top_k, n_levels)`, one slice per task level.
+
+See [examples/multilevel_example.py](https://github.com/InseeFrLab/torchTextClassifiers/blob/main/examples/multilevel_example.py)
+for the full runnable script.
+
+---
+
+## contrib reference
+
+| Class | Description |
+|---|---|
+| `MultiLevelTextClassificationModel` | Shared `TokenEmbedder`, one `SentenceEmbedder` + `ClassificationHead` per task |
+| `MultiLevelCrossEntropyLoss` | Per-task cross-entropy, optionally weighted by `num_classes` |
+
+```python
+from torchTextClassifiers.contrib import (
+    MultiLevelTextClassificationModel,
+    MultiLevelCrossEntropyLoss,
+)
+```
+
+These classes are reference implementations — use them directly or as a
+starting point for your own architecture.
+
+---
+
+## Saving and loading
+
+`save` and `load` work the same way regardless of which path was used. Custom
+models are serialised as a pickle of the model structure plus a separate
+state-dict file; the `is_custom_model` flag in the saved metadata tells
+`load` which strategy to use.
+
+```python
+classifier.save("my_classifier/")
+loaded = torchTextClassifiers.load("my_classifier/")
+```
+
+---
+
+## Next steps
+
+- **Architecture overview**: {doc}`../architecture/overview` — component reference and design philosophy
+- **API reference**: {doc}`../api/wrapper` — full `torchTextClassifiers` API
+- **contrib source**: `torchTextClassifiers/contrib/multilevel.py`
diff --git a/docs/source/tutorials/index.md b/docs/source/tutorials/index.md
index 845c221..160f53c 100644
--- a/docs/source/tutorials/index.md
+++ b/docs/source/tutorials/index.md
@@ -10,6 +10,7 @@ multiclass_classification
 mixed_features
 explainability
 multilabel_classification
+custom_model
 ```
 
 ## Overview
@@ -116,6 +117,21 @@ Assign multiple labels to each text sample for complex classification scenarios.
 
 **Difficulty:** Advanced | **Time:** 30 minutes
 :::
+
+:::{grid-item-card} {fas}`puzzle-piece` Custom Architectures
+:link: custom_model
+:link-type: doc
+
+Plug any PyTorch model into the torchTextClassifiers wrapper via `from_model`.
+ +**What you'll learn:** +- When to go beyond `TextClassificationModel` +- The required `forward` / `num_classes` / `categorical_variable_net` interface +- Using `contrib` classes as reference implementations +- Multi-task classification with `MultiLevelTextClassificationModel` + +**Difficulty:** Advanced | **Time:** 30 minutes +::: + :::: ## Learning Path @@ -130,6 +146,8 @@ graph LR C --> F[Multilabel Classification] D --> E[Explainability] F --> E + D --> G[Custom Architectures] + F --> G style A fill:#e3f2fd style B fill:#bbdefb @@ -137,6 +155,7 @@ graph LR style D fill:#64b5f6 style E fill:#1976d2 style F fill:#42a5f5 + style G fill:#0d47a1 ``` 1. **Start with**: {doc}`../getting_started/quickstart` - Get familiar with the basics diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py index 0541c99..3b05477 100644 --- a/torchTextClassifiers/torchTextClassifiers.py +++ b/torchTextClassifiers/torchTextClassifiers.py @@ -246,17 +246,57 @@ def from_model( value_encoder: Optional[ValueEncoder] = None, ragged_multilabel: Optional[bool] = False, ): - """Initialize torchTextClassifiers from a pre-built PyTorch model. + """Initialize torchTextClassifiers from a custom pre-built PyTorch model. - This method allows users to create a torchTextClassifiers instance using a pre-built PyTorch model that may not follow the standard architecture expected by the main constructor. The provided model should be compatible with the input format used in the predict method (i.e., it should accept tokenized text and categorical variables as input). + Use this when the standard ``TextClassificationModel`` (built automatically + from ``ModelConfig``) cannot express your architecture — for example when you + need multiple classification heads, shared encoders across tasks, or any other + custom topology. The wrapper then provides the usual ``predict`` / ``save`` / + ``load`` interface around your model. + + **Required interface for** ``pytorch_model``: + + 1. **``forward`` signature** — the model must accept exactly these keyword + arguments (extra ``**kwargs`` are forwarded but ignored by the wrapper):: + + def forward( + self, + input_ids: torch.Tensor, # (batch, seq_len) Long + attention_mask: torch.Tensor, # (batch, seq_len) int + categorical_vars: torch.Tensor, # (batch, n_cats) Long — may be None + **kwargs, + ) -> torch.Tensor | list[torch.Tensor]: + ... + + The return value must be **raw logits** (not softmaxed). For standard + single-task classification return a tensor of shape + ``(batch, num_classes)``. For multi-task classification you may return a + list of such tensors, one per task. + + 2. **``num_classes`` attribute** — must be an ``int`` (single task) or a + ``list[int]`` (multi-task, one entry per task head). + + 3. **``categorical_variable_net`` attribute** — the ``CategoricalVariableNet`` + module used by the model, or ``None`` if no categorical features are used. + The wrapper reads ``categorical_variable_net.categorical_vocabulary_sizes`` + to set up the data pipeline. + + See ``torchTextClassifiers.contrib`` for ready-made example architectures + (``MultiLevelTextClassificationModel``, ``MultiLevelCrossEntropyLoss``) that + follow this interface. Args: - tokenizer: A tokenizer instance for text preprocessing - pytorch_model: A pre-built PyTorch model to be used for predictions - value_encoder: Optional ValueEncoder for encoding raw string (or mixed) categorical values to integers. 
Build it beforehand from DictEncoder or sklearn LabelEncoder instances and pass it here. If None, categorical columns in X must already be integer-encoded. + tokenizer: A tokenizer instance for text preprocessing. + pytorch_model: A pre-built PyTorch model satisfying the interface above. + value_encoder: Optional ``ValueEncoder`` for encoding raw string (or + mixed) categorical values to integers. Build it from ``DictEncoder`` + or sklearn ``LabelEncoder`` instances and pass it here. If ``None``, + categorical columns in ``X`` must already be integer-encoded. + ragged_multilabel: Set to ``True`` for ragged multi-label targets + (variable number of labels per sample). Returns: - An instance of torchTextClassifiers initialized with the provided model and tokenizer. + An instance of torchTextClassifiers wrapping the provided model. """ instance = cls.__new__(cls) instance.tokenizer = tokenizer