
Implementing Linear Regression in PyTorch

This guide shows how to subclass nn.Module to implement linear regression in PyTorch, and contrasts this style with the formula-based approach familiar from R.

For statisticians already familiar with R, fitting a linear regression is often as simple as calling lm(y ~ x, data), R's built-in function for linear models, and then extracting the coefficients, predictions, and mean squared error:

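```r
model <- lm(y ~ x, data)           # fit y = b0 + b1 * x by ordinary least squares
coef(model)                        # estimated intercept and slope
y_pred <- predict(model, newdata = data.frame(x = c(1, 2, 3, 4, 5)))  # predictions at new x values
mse <- mean(residuals(model)^2)    # in-sample mean squared error on the fitted data
```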
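For reference, the model that lm(y ~ x, data) fits, the ordinary least squares estimator, and the mean squared error can be written as follows. The closed form assumes the design matrix X, which has an intercept column of ones, has full column rank, so that X^T X is invertible:

$$
y_i = \beta_0 + \beta_1 x_i + \varepsilon_i, \qquad i = 1, \dots, n
$$

$$
\hat{\boldsymbol\beta} = \arg\min_{\boldsymbol\beta} \sum_{i=1}^{n} \left(y_i - \beta_0 - \beta_1 x_i\right)^2 = (X^\top X)^{-1} X^\top \mathbf{y}, \qquad X = \begin{bmatrix} 1 & x_1 \\ \vdots & \vdots \\ 1 & x_n \end{bmatrix}
$$

$$
\mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left(y_i - \hat{y}_i\right)^2, \qquad \hat{y}_i = \hat\beta_0 + \hat\beta_1 x_i
$$

Numerically, lm obtains these estimates through a QR decomposition of X rather than by inverting X^T X. The two entries of the estimate are what coef(model) returns. In PyTorch, they correspond to the bias and the weight held by an nn.Linear(1, 1) layer.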

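In a deep learning framework such as PyTorch, the process is more explicit. You define the model architecture and its parameters yourself by subclassing nn.Module, and the forward method specifies the computation from which PyTorch builds the computational graph each time the model is called.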

```python
import torch
import torch.nn as nn


class LinearRegression(nn.Module):
    def __init__(self, input_dim, output_dim):
        """
        Custom PyTorch module for linear regression.

        Args:
            input_dim (int): Number of input features.
            output_dim (int): Number of output features (typically 1 for regression).
        """
        super().__init__()

        # Single linear layer: computes x @ W.T + b with learnable weight W and bias b
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """
        Forward pass of the model.

        Args:
            x (Tensor): Input tensor of shape (batch_size, input_dim).

        Returns:
            Tensor: Predicted output of shape (batch_size, output_dim).
        """
        return self.linear(x)


# Example usage:
if __name__ == "__main__":
    # Define a model mapping a 1D input to a 1D output
    model = LinearRegression(input_dim=1, output_dim=1)

    # Sample input tensor: a batch of 3 samples with 1 feature each
    x_sample = torch.tensor([[1.0], [2.0], [3.0]])

    # Forward pass; the model is untrained, so its weights are still randomly initialized
    y_pred = model(x_sample)

    print("Predicted Output:\n", y_pred)
```
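The example above only runs a forward pass through an untrained model, so its outputs come from randomly initialized weights. The following is a minimal training sketch that mirrors the R workflow. It makes several assumptions: the data are hypothetical and synthetic (y = 2x + 1 plus Gaussian noise), training uses full-batch gradient descent with torch.optim.SGD and nn.MSELoss, and the learning rate and epoch count are illustrative rather than tuned. To keep the snippet self-contained, it uses nn.Linear(1, 1) directly, which is the layer that LinearRegression(input_dim=1, output_dim=1) wraps. After training, it reads off the intercept and slope, predicts at x = 1, ..., 5, computes the in-sample MSE, and compares the result with the closed-form least-squares solution from torch.linalg.lstsq.

```python
import torch
import torch.nn as nn

# Hypothetical synthetic data: y = 2x + 1 plus Gaussian noise, 100 samples
torch.manual_seed(0)
x = torch.linspace(0, 5, 100).unsqueeze(1)  # shape (100, 1)
y = 2.0 * x + 1.0 + 0.3 * torch.randn_like(x)

# nn.Linear(1, 1) is the layer that LinearRegression(input_dim=1, output_dim=1) wraps
model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)  # illustrative learning rate

# Full-batch gradient descent on the mean squared error
for epoch in range(2000):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

# Analogue of coef(model): intercept (bias) and slope (weight)
intercept = model.bias.item()
slope = model.weight.item()

# Analogues of predict(model, newdata = ...) and the in-sample MSE; no gradients needed
with torch.no_grad():
    x_new = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])
    y_new = model(x_new)
    mse = loss_fn(model(x), y).item()

# Closed-form OLS for comparison: least-squares solution of [1, x] @ beta ~ y
X = torch.cat([torch.ones_like(x), x], dim=1)  # design matrix with an intercept column
beta = torch.linalg.lstsq(X, y).solution       # shape (2, 1): [intercept, slope]

print(f"SGD: intercept={intercept:.3f}, slope={slope:.3f}, MSE={mse:.4f}")
print(f"OLS: intercept={beta[0].item():.3f}, slope={beta[1].item():.3f}")
print("Predictions at x = 1..5:", y_new.squeeze(1).tolist())
```

To use the custom module instead, swap in model = LinearRegression(input_dim=1, output_dim=1). The training loop stays the same, because the module exposes its parameters through model.parameters(). The slope and intercept then live in model.linear.weight and model.linear.bias.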
