Document Categorization Prediction with TorchText
🧠 1. Document Classifier Overview
A document classifier is a model that takes raw text (like news articles) and predicts which category it belongs to, e.g., sports, business, or science.
- Input: Raw text
- Output: A predicted class label (e.g., "sports")
- How: Converts text → numbers → applies a neural network → outputs a classification
🔢 2. Neural Network Basics
A neural network is a function made up of layers of connected "neurons" (really just numbers and matrix operations). Here's how the flow works:
🔹 Layers in a Neural Network:
- Input Layer: Accepts a numeric representation of the text (e.g., bag-of-words or embeddings).
- Hidden Layers: Perform matrix multiplications and apply activation functions like ReLU or sigmoid. These layers "learn" internal features.
- Output Layer: Produces a vector of logits, one number per possible class.
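To make the stack concrete, here is a minimal sketch of such a network in PyTorch; the sizes (100-dimensional input, 64 hidden units, 4 classes) are illustrative assumptions, not values from these notes:

```python
import torch
import torch.nn as nn

# Input layer -> hidden layer (ReLU) -> output layer (logits).
net = nn.Sequential(
    nn.Linear(100, 64),  # input: 100-dim numeric text representation
    nn.ReLU(),           # non-linear activation in the hidden layer
    nn.Linear(64, 4),    # output: one logit per class
)

x = torch.randn(1, 100)  # one document as a 100-dim vector
logits = net(x)          # shape (1, 4): raw, unnormalized scores
```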
🔹 Logits:
- Raw, unnormalized scores (can be negative or positive).
- Not probabilities, just signals for classification.
🔹 Argmax Function:
- Applied to the output logits.
- Returns the index of the largest value, which corresponds to the predicted class.
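For example (with made-up scores), argmax just picks the position of the largest logit:

```python
import torch

logits = torch.tensor([[-1.2, 3.4, 0.5, -0.3]])  # made-up logits for 4 classes
predicted = torch.argmax(logits, dim=1)          # tensor([1]) -> class index 1
```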
⚙️ 3. Hyperparameters
Hyperparameters are configurations you set manually when designing a neural network. Common ones include:

| Hyperparameter | Description |
| --- | --- |
| Number of hidden layers | Depth of the network |
| Neurons per layer | Width or complexity per layer |
| Embedding dimension | Size of the dense word vector |
| Number of output classes | Equals the number of possible categories (e.g., 4 for news: world, sports, business, tech) |

These are not learned; you tune them using validation data.
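In code, these usually show up as constants chosen before training; the values below are illustrative assumptions (only the 4 output classes come from the AG News setup):

```python
VOCAB_SIZE = 20000  # number of tokens in the vocabulary (dataset-dependent)
EMBED_DIM = 64      # embedding dimension (a common small choice)
NUM_CLASS = 4       # AG News has 4 categories
```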
🧰 4. PyTorch Implementation Overview
🔹 Dataset: AG News Dataset
- Each row = (label, text)
- Labels are mapped to categories like:
- 0 = World
- 1 = Sports
- 2 = Business
- 3 = Science & Tech
🔹 Processing Pipeline:
- Tokenization → Convert raw text into tokens (e.g., "I like cats" → ["I", "like", "cats"]).
- Vocabulary → Each token gets an index.
- Indexing → Tokens in the text are replaced with their corresponding indices.
- Offsets → Track where each document starts in a flattened tensor (see the sketch below).
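A minimal sketch of this pipeline using the (now-deprecated) torchtext utilities; the two sample documents are placeholders:

```python
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer("basic_english")
docs = ["I like cats", "Stocks rallied today"]  # placeholder documents

# Vocabulary: assign an index to every token seen in the corpus.
vocab = build_vocab_from_iterator(
    (tokenizer(doc) for doc in docs), specials=["<unk>"]
)
vocab.set_default_index(vocab["<unk>"])  # unseen tokens map to <unk>

# Indexing: replace each token with its vocabulary index.
indexed = [torch.tensor(vocab(tokenizer(doc))) for doc in docs]

# Offsets: flatten all documents into one tensor and record start positions.
text = torch.cat(indexed)                     # flattened token indices
offsets = torch.tensor([0, len(indexed[0])])  # where each document begins
```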
🧱 5. Model Architecture
🔸 1. Embedding Bag Layer
- Similar to `nn.Embedding`, but directly aggregates multiple tokens into one vector (by summing or averaging).
- Input: token indices and offsets
- Output: a single vector for the document
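A tiny standalone demo of `nn.EmbeddingBag` with offsets (toy sizes, illustrative only):

```python
import torch
import torch.nn as nn

bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=3, mode="mean")

# Two documents flattened into one tensor: [1, 2, 4] and [5, 4, 3, 2].
text = torch.tensor([1, 2, 4, 5, 4, 3, 2])
offsets = torch.tensor([0, 3])  # each document's starting position

out = bag(text, offsets)  # shape (2, 3): one averaged vector per document
```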
🔸 2. Fully Connected (Linear) Layer
- Maps the aggregated vector to output logits (one per category).
```python
import torch.nn as nn

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # Aggregates all token embeddings in a document into one vector.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        # Maps the aggregated vector to one logit per category.
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
```
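Instantiated with the illustrative hyperparameters from Section 3:

```python
model = TextClassificationModel(VOCAB_SIZE, EMBED_DIM, NUM_CLASS)
```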
🧪 6. Prediction Workflow
- Input: Tokenized and indexed text + offset
- Pass through the embedding bag → get a dense representation
- Pass through the fully connected layer → get logits
- Apply argmax → get the predicted class
```python
output = model(text_tensor, offset_tensor)  # logits, shape (batch, num_class)
prediction = torch.argmax(output, dim=1)    # predicted class index per document
```
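Putting the workflow together, a hypothetical `predict` helper (it assumes the `tokenizer`, `vocab`, and `model` from the sketches above, plus the label mapping from Section 4):

```python
ag_news_label = {0: "World", 1: "Sports", 2: "Business", 3: "Science & Tech"}

def predict(raw_text):
    # Tokenize and index the document; a single document starts at offset 0.
    text_tensor = torch.tensor(vocab(tokenizer(raw_text)))
    offset_tensor = torch.tensor([0])
    with torch.no_grad():  # inference only, no gradients needed
        output = model(text_tensor, offset_tensor)
    return ag_news_label[torch.argmax(output, dim=1).item()]
```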
📦 7. Batching
- Batching is used to process multiple documents at once.
- Use PyTorch's `DataLoader` to batch inputs.
- Create a batch (collate) function, sketched below, that:
  - Flattens text indices from all samples into one tensor.
  - Adds an offset to mark where each sample starts.
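A sketch of such a collate function, following the common PyTorch text-classification pattern; `text_pipeline` (tokenize, then index) and `train_dataset` are assumed placeholders:

```python
import torch
from torch.utils.data import DataLoader

def text_pipeline(raw_text):
    return vocab(tokenizer(raw_text))  # token indices for one document

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for label, raw_text in batch:
        label_list.append(label)
        indices = torch.tensor(text_pipeline(raw_text), dtype=torch.int64)
        text_list.append(indices)
        offsets.append(indices.size(0))
    labels = torch.tensor(label_list, dtype=torch.int64)
    # Cumulative lengths give each sample's starting position.
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text = torch.cat(text_list)  # all token indices flattened into one tensor
    return labels, text, offsets

loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
                    collate_fn=collate_batch)
```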
✅ Recap Summary

| Concept | Explanation |
| --- | --- |
| Neural Network | Transforms numeric input into a classification through layers of weights and activations |
| Embedding Bag | Maps word indices to dense vectors and aggregates them for full documents |
| Logits | Raw output scores (one per class), used before applying argmax |
| Argmax | Selects the index of the highest logit to predict the document's class |
| Hyperparameters | Settings like number of layers, neurons, and embedding size, manually tuned |
| Text Pipeline | Tokenization → Indexing → Offsets → Embedding → Classification |
| Batching | Combines multiple samples for efficient processing, with offsets to track each |