Introduction

Long Short-Term Memory (LSTM) networks are a type of Recurrent Neural Network (RNN) designed to handle the vanishing gradient problem, which is common in traditional RNNs. LSTMs are particularly effective for tasks involving sequential data, such as time series forecasting, natural language processing, and speech recognition.

Key Concepts

  1. Cell State: The cell state acts as a conveyor belt, running through the entire chain with only minor linear interactions, allowing information to flow through relatively unchanged.
  2. Gates: LSTMs use gates to control the flow of information. There are three types of gates:
    • Forget Gate: Decides what information to discard from the cell state.
    • Input Gate: Decides which new values from the input will be written to the cell state.
    • Output Gate: Decides what part of the cell state to output.
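
For reference, the three gates combine into the standard LSTM update equations (the notation matches the code examples below), where \(\sigma\) is the sigmoid function and \(*\) denotes element-wise multiplication:

\(f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)\)
\(i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)\)
\(\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)\)
\(C_t = f_t * C_{t-1} + i_t * \tilde{C}_t\)
\(o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)\)
\(h_t = o_t * \tanh(C_t)\)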

LSTM Architecture

Forget Gate

The forget gate decides what information to discard from the cell state. It applies a sigmoid to the previous hidden state and the current input, producing a number between 0 and 1 for each element of the previous cell state \(C_{t-1}\): 1 means "keep this entirely," while 0 means "forget this entirely."

import torch
import torch.nn as nn

# Example of forget gate
input_size = 10
hidden_size = 20

x_t = torch.randn(1, input_size)  # Input at time t
h_t_minus_1 = torch.randn(1, hidden_size)  # Hidden state at time t-1
C_t_minus_1 = torch.randn(1, hidden_size)  # Cell state at time t-1

forget_gate = nn.Linear(input_size + hidden_size, hidden_size)  # plays the role of W_f and b_f
combined = torch.cat((x_t, h_t_minus_1), 1)  # concatenate [x_t, h_{t-1}]
f_t = torch.sigmoid(forget_gate(combined))   # values in (0, 1)

print(f"Forget gate output: {f_t}")

Input Gate

The input gate decides what new information to store in the cell state. It consists of two parts: a sigmoid layer that decides which values to update, and a tanh layer that creates the candidate values \(\tilde{C}_t\) to be added to the state.

# Example of input gate
input_gate = nn.Linear(input_size + hidden_size, hidden_size)
input_transform = nn.Linear(input_size + hidden_size, hidden_size)

i_t = torch.sigmoid(input_gate(combined))          # which values to update
C_tilde_t = torch.tanh(input_transform(combined))  # candidate values for the update

print(f"Input gate output: {i_t}")
print(f"Input transform output: {C_tilde_t}")

Cell State Update

The cell state is updated by combining the results of the previous two steps: the forget gate scales the old cell state, and the input gate scales the new candidate values.

# Update cell state
C_t = f_t * C_t_minus_1 + i_t * C_tilde_t  # forget part of the old state, add the scaled candidates

print(f"Updated cell state: {C_t}")

Output Gate

The output gate decides which parts of the cell state become the next hidden state. A sigmoid layer selects the elements to expose, and the updated cell state is passed through tanh before being scaled by that selection.

# Example of output gate
output_gate = nn.Linear(input_size + hidden_size, hidden_size)
o_t = torch.sigmoid(output_gate(combined))  # which parts of the state to expose
h_t = o_t * torch.tanh(C_t)                 # next hidden state

print(f"Output gate output: {o_t}")
print(f"Next hidden state: {h_t}")

Building an LSTM from Scratch

Step-by-Step Implementation

  1. Define the LSTM Class: Create a class that inherits from nn.Module.
  2. Initialize Layers: Define the forget gate, input gate, and output gate layers.
  3. Forward Pass: Implement the forward pass to compute the hidden state and cell state.

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        
        self.forget_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.input_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.input_transform = nn.Linear(input_size + hidden_size, hidden_size)
        self.output_gate = nn.Linear(input_size + hidden_size, hidden_size)
    
    def forward(self, x, h_prev, C_prev):
        combined = torch.cat((x, h_prev), 1)
        
        f_t = torch.sigmoid(self.forget_gate(combined))
        i_t = torch.sigmoid(self.input_gate(combined))
        C_tilde_t = torch.tanh(self.input_transform(combined))
        C_t = f_t * C_prev + i_t * C_tilde_t
        
        o_t = torch.sigmoid(self.output_gate(combined))
        h_t = o_t * torch.tanh(C_t)
        
        return h_t, C_t

# Example usage
input_size = 10
hidden_size = 20
lstm = LSTM(input_size, hidden_size)

x_t = torch.randn(1, input_size)
h_t_minus_1 = torch.randn(1, hidden_size)
C_t_minus_1 = torch.randn(1, hidden_size)

h_t, C_t = lstm(x_t, h_t_minus_1, C_t_minus_1)
print(f"Next hidden state: {h_t}")
print(f"Next cell state: {C_t}")

Practical Exercise

Task

Create an LSTM network to predict the next value in a simple time series.

Steps

  1. Generate Data: Create a simple sine wave dataset.
  2. Define the LSTM Model: Use PyTorch's nn.LSTM module.
  3. Train the Model: Implement the training loop.
  4. Evaluate the Model: Predict and plot the results.

Solution

import numpy as np
import matplotlib.pyplot as plt

# Generate sine wave data
time_steps = np.linspace(0, 100, 1000)
data = np.sin(time_steps)

# Prepare data for the LSTM: slide a window of length `tw` over the series so that
# each window is an input sequence and the value right after it is the label
def create_inout_sequences(input_data, tw):
    inout_seq = []
    L = len(input_data)
    for i in range(L-tw):
        train_seq = input_data[i:i+tw]
        train_label = input_data[i+tw:i+tw+1]
        inout_seq.append((train_seq, train_label))
    return inout_seq

train_window = 10
# Convert the NumPy array to a tensor so the sequences can be fed directly to the model
data_tensor = torch.FloatTensor(data)
train_inout_seq = create_inout_sequences(data_tensor, train_window)

# Define LSTM model
class LSTMModel(nn.Module):
    def __init__(self, input_size=1, hidden_layer_size=100, output_size=1):
        super(LSTMModel, self).__init__()
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size, hidden_layer_size)
        self.linear = nn.Linear(hidden_layer_size, output_size)
        self.hidden_cell = (torch.zeros(1,1,self.hidden_layer_size),
                            torch.zeros(1,1,self.hidden_layer_size))

    def forward(self, input_seq):
        lstm_out, self.hidden_cell = self.lstm(input_seq.view(len(input_seq), 1, -1), self.hidden_cell)
        predictions = self.linear(lstm_out.view(len(input_seq), -1))
        return predictions[-1]

model = LSTMModel()
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 150
for i in range(epochs):
    for seq, labels in train_inout_seq:
        optimizer.zero_grad()
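        # Reset the hidden and cell states so each training sequence starts from a clean state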
        model.hidden_cell = (torch.zeros(1, 1, model.hidden_layer_size),
                             torch.zeros(1, 1, model.hidden_layer_size))

        y_pred = model(seq)

        single_loss = loss_function(y_pred, labels)
        single_loss.backward()
        optimizer.step()

    if i % 25 == 1:
        print(f'epoch: {i:3} loss: {single_loss.item():10.8f}')

# Predict and plot
fut_pred = 100
test_inputs = data[-train_window:].tolist()

model.eval()

for i in range(fut_pred):
    seq = torch.FloatTensor(test_inputs[-train_window:])
    with torch.no_grad():
        # Re-initialize the states before each prediction step
        model.hidden_cell = (torch.zeros(1, 1, model.hidden_layer_size),
                             torch.zeros(1, 1, model.hidden_layer_size))
        test_inputs.append(model(seq).item())

# Plot the true series and the predictions on a shared sample-index axis
x = np.arange(len(data), len(data) + fut_pred)
plt.title('Sine Wave Prediction')
plt.grid(True)
plt.plot(np.arange(len(data)), data, label='True Data')
plt.plot(x, test_inputs[train_window:], label='Predictions')
plt.legend()
plt.show()

Summary

In this section, we covered the architecture and functionality of Long Short-Term Memory (LSTM) networks. We explored the different gates that control the flow of information and implemented an LSTM from scratch. Finally, we applied an LSTM to a practical time series prediction task. This knowledge prepares you for more complex sequential data tasks and sets the foundation for understanding advanced RNN architectures.
