Tuesday, 17 March 2026

A Gentle Introduction to Graph Neural Networks in Python

 

Introduction

Graph neural networks (GNNs) can be pictured as a special class of neural network models in which the data, both the training data and the real-world data used for inference, are structured as graphs rather than as fixed-size vectors or grids such as images, sequences, or instances of tabular data.

While conventional neural network architectures like feed-forward models excel at predictive problems such as classification on structured, tabular data or images, GNNs are designed for problems where the relationships between data entities are complex and irregular, for instance social networks, molecular structures, and knowledge graphs. The input data used for training and inference in GNNs is represented as a graph, with nodes representing entities (e.g. users in a social network) and edges representing relationships (e.g. friendships or follows between users).

Interested in better understanding how GNNs work through a gentle practical example in Python? Then keep reading.

Defining a Graph Neural Network in Python

In this introductory example of building a GNN, we will consider a small graph dataset associated with a social media platform, where each node represents a person and each edge connecting any two nodes is a friendship between persons. Furthermore, each node (person) has associated features like the person’s age, their interests, etc.

The target task of the GNN we will build is to classify people as either popular or not popular in the social network (binary classification), depending on whether or not they have more than two friends in it, and taking into account:

  1. The person’s features, such as their interests
  2. The person’s connections with other persons

Therefore, GNNs give an extra layer of sophistication to predictive tasks, because they not only look at the target instance’s features to make a prediction but also at its relationship with other data instances, unlike classical classification and regression models.

Without further ado, let’s start coding. We’ll use PyTorch together with PyTorch Geometric (the torch_geometric package), both suitable for building GNNs, so we start by installing them first, typically with pip install torch torch_geometric.

Now the necessary imports:

This is our “mini-social network” dataset or graph:

Basically, edge_index is a matrix of edges or connections between users. There are 5 users, numbered 0 to 4. The first connection is from user 0 to user 1, and we know this by looking at the first element in each row of the matrix. The second connection is the reciprocal of the previous one: user 1 to user 0. Then comes user 0 to user 2, and so on. User 3 seems not to be connected to anyone yet!
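As a minimal sketch of this two-row layout, the snippet below uses plain Python lists rather than a tensor so that it runs without dependencies. The specific friendships are an assumption chosen to match the description (5 users, reciprocal pairs starting with users 0 and 1, then 0 and 2, and user 3 isolated); they are illustrative values, not the article's exact data.

```python
# Hypothetical friendship graph for users 0..4 (illustrative values only).
# Row 0 holds the source of each connection and row 1 the target, so
# column k reads "edge_index[0][k] -> edge_index[1][k]".
edge_index = [
    [0, 1, 0, 2, 0, 4, 1, 2],  # sources
    [1, 0, 2, 0, 4, 0, 2, 1],  # targets
]

# Pair up the two rows to recover the individual connections.
edges = list(zip(edge_index[0], edge_index[1]))
print(edges[:3])  # [(0, 1), (1, 0), (0, 2)]

# A user that never appears in either row has no friends yet.
num_users = 5
connected = set(edge_index[0]) | set(edge_index[1])
isolated = [user for user in range(num_users) if user not in connected]
print(isolated)  # [3]
```

In PyTorch Geometric, the same two rows would typically be stored in an integer tensor of shape [2, num_edges].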

Now we model two numerical features for each person, in a tensor node_features: the person’s age, and their interest in sports, with 1 indicating interest and 0 indicating no interest.
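A feature table of this kind can be sketched as one row per user. The ages and interest flags below are made-up values for illustration, and the rescaling step is an assumption (a common preprocessing choice), not something the article prescribes.

```python
# Hypothetical node features, one row per user: [age, interested_in_sports].
node_features = [
    [25, 1],
    [30, 0],
    [22, 1],
    [35, 0],
    [27, 1],
]

# Ages are much larger than the 0/1 flag, so one common option is to
# rescale them into a comparable range before training.
max_age = max(age for age, _ in node_features)
scaled_features = [[age / max_age, float(sports)] for age, sports in node_features]

for user, row in enumerate(scaled_features):
    print(user, row)
```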

Visualizing a Graph Neural Network in Python

One way to visualize our graph in Python is with the NetworkX library, which builds a graph from the edge list, together with Matplotlib, which displays it. An example of this is below.

Figure 1: Visualization of the social network graph

Building a Graph Neural Network Model in Python

Now we define labels for the dataset of users, i.e. whether a person is popular or not, based on whether the person has more than 2 friends or not. The process entails calculating the number of friends of each person (ground truth) based on the adjacency matrix.
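The labeling rule can be sketched by counting friends directly from the edge list. The graph below is a hypothetical one (illustrative edges, with user 3 isolated), written with plain Python lists so it runs without dependencies.

```python
# Hypothetical edge list: every friendship is stored in both directions.
edge_index = [
    [0, 1, 0, 2, 0, 4, 1, 2],  # sources
    [1, 0, 2, 0, 4, 0, 2, 1],  # targets
]
num_users = 5

# Because each friendship appears once per direction, counting how often
# a user shows up as a source gives that user's number of friends.
friend_counts = [0] * num_users
for source in edge_index[0]:
    friend_counts[source] += 1

# Label 1 (popular) if the user has more than 2 friends, else 0.
labels = [1 if count > 2 else 0 for count in friend_counts]

print(friend_counts)  # [3, 2, 2, 0, 1]
print(labels)         # [1, 0, 0, 0, 0]
```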

Using the following mask, we will indicate that the first three people will be used as training data to build the GNN, and the other two will be used later for inference. Finally, we also wrap everything into a Data object.
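The role of the mask can be illustrated with a plain dictionary standing in for the Data object. This is only a sketch: the values are hypothetical, and the key names (x, edge_index, y, train_mask) simply mirror the attribute names conventionally used in PyTorch Geometric.

```python
# Hypothetical labels for users 0..4 (illustrative values).
labels = [1, 0, 0, 0, 0]

# True marks a node whose label is used during training.
train_mask = [True, True, True, False, False]

# A plain dict as a stand-in for a graph container.
graph = {
    "x": [[25, 1], [30, 0], [22, 1], [35, 0], [27, 1]],
    "edge_index": [[0, 1, 0, 2, 0, 4, 1, 2], [1, 0, 2, 0, 4, 0, 2, 1]],
    "y": labels,
    "train_mask": train_mask,
}

# The mask splits the nodes without removing anything from the graph:
# held-out nodes still take part in message passing through their edges.
train_nodes = [i for i, keep in enumerate(graph["train_mask"]) if keep]
held_out_nodes = [i for i, keep in enumerate(graph["train_mask"]) if not keep]
train_labels = [graph["y"][i] for i in train_nodes]

print(train_nodes)     # [0, 1, 2]
print(held_out_nodes)  # [3, 4]
print(train_labels)    # [1, 0, 0]
```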

The next piece of code is crucial. It defines the GNN architecture and instantiates the model. In PyTorch, GNN models can be built by using graph convolutional layers, such as the ones implemented by the GCNConv class in torch_geometric.nn. Graph convolutional layers aggregate information from a node’s neighbors, helping learn representations that capture not only node features but also structural relationships in the graph.
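To make the idea of neighbor aggregation concrete, here is a dependency-free sketch of the propagation step only: each node averages its own and its neighbors' features using self-loops and symmetric degree normalization, the scheme used in graph convolutional networks. The learned weight matrix and the nonlinearity are deliberately omitted, and the function name gcn_aggregate, the graph, and the feature values are all hypothetical.

```python
import math

def gcn_aggregate(features, edge_index):
    """One normalized neighbor-aggregation step (no learned weights)."""
    num_nodes = len(features)
    # Start every neighborhood with the node itself (self-loop).
    neighbors = {i: {i} for i in range(num_nodes)}
    for src, dst in zip(edge_index[0], edge_index[1]):
        neighbors[dst].add(src)
    degree = {i: len(neighbors[i]) for i in range(num_nodes)}

    aggregated = []
    for i in range(num_nodes):
        row = [0.0] * len(features[i])
        for j in sorted(neighbors[i]):
            # Symmetric normalization: 1 / sqrt(deg(i) * deg(j)).
            weight = 1.0 / math.sqrt(degree[i] * degree[j])
            for k, value in enumerate(features[j]):
                row[k] += weight * value
        aggregated.append(row)
    return aggregated

# Hypothetical graph and (rescaled) features for users 0..4.
edge_index = [
    [0, 1, 0, 2, 0, 4, 1, 2],
    [1, 0, 2, 0, 4, 0, 2, 1],
]
features = [[0.25, 1.0], [0.30, 0.0], [0.22, 1.0], [0.35, 0.0], [0.27, 1.0]]

hidden = gcn_aggregate(features, edge_index)
for user, row in enumerate(hidden):
    print(user, [round(value, 3) for value in row])
```

Note that the isolated user 3 keeps its own features unchanged, since its only "neighbor" is its self-loop, while users 1 and 2 end up with identical representations because they share the same neighborhood.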

Training a Graph Neural Network in Python

The training process is reasonably similar to training other types of neural network models in PyTorch:
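The shape of such a loop (forward pass, loss on the training nodes only, gradient step) can be sketched without any framework. The example below swaps in a much simpler technique than a full GNN: a single logistic unit trained by hand-written gradient descent on made-up per-node inputs. All values and names are hypothetical; in PyTorch, the gradients and updates would instead be handled by automatic differentiation and an optimizer.

```python
import math

# Illustrative inputs (made-up values): one feature vector per node,
# e.g. the result of a neighbor-aggregation step, plus labels and mask.
hidden = [[0.26, 0.75], [0.27, 0.55], [0.28, 0.57], [0.35, 0.0], [0.30, 1.0]]
labels = [1, 0, 0, 0, 0]
train_mask = [True, True, True, False, False]

def predict_proba(row, weights, bias):
    """Probability of class 1 from a single logistic unit."""
    z = sum(w * v for w, v in zip(weights, row)) + bias
    return 1.0 / (1.0 + math.exp(-z))

weights, bias = [0.0, 0.0], 0.0
learning_rate = 0.5
train_nodes = [i for i, keep in enumerate(train_mask) if keep]
losses = []

for epoch in range(200):
    grad_w, grad_b, loss = [0.0, 0.0], 0.0, 0.0
    # Only nodes selected by the mask contribute to loss and gradients.
    for i in train_nodes:
        p = predict_proba(hidden[i], weights, bias)
        y = labels[i]
        loss -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
        error = p - y
        for k, value in enumerate(hidden[i]):
            grad_w[k] += error * value
        grad_b += error
    n = len(train_nodes)
    losses.append(loss / n)
    # Plain gradient descent update.
    weights = [w - learning_rate * g / n for w, g in zip(weights, grad_w)]
    bias -= learning_rate * grad_b / n
    if epoch % 50 == 0:
        print(f"epoch {epoch}: loss {losses[-1]:.4f}")
```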

Sample training output:

Graph Neural Network Inference in Python

Once the GNN has been trained, inference is straightforward. We pass the full dataset through the model to compute popularity predictions, including for the two users that were not seen during training, and print the results. The argmax function picks, for each user, the class with the highest predicted score among the two available classes; with only two classes, this is equivalent to the thresholded decision made by a binary classifier such as logistic regression.
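The argmax step itself can be sketched in a few lines. The scores below are made-up numbers chosen purely for illustration; they are not the output of the trained model.

```python
# Made-up per-class scores for five users (NOT actual model output).
# Column 0 is the score for "not popular", column 1 for "popular".
scores = [
    [-1.2, -0.4],
    [-0.9, -0.5],
    [-1.5, -0.3],
    [-0.2, -1.7],
    [-1.1, -0.4],
]

def argmax(values):
    """Index of the largest value in a list."""
    return max(range(len(values)), key=lambda i: values[i])

predictions = [argmax(row) for row in scores]
print(predictions)  # [1, 1, 1, 0, 1]
```

With two classes, argmax returns 1 exactly when the second score exceeds the first, which is why it behaves like a simple threshold.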

This is the resulting list of predictions:

So, we can see that all users are deemed popular except user 3, a.k.a. the “lonely user.”

Wrapping Up

To sum up, we have built a very simple GNN that uses a graph representation of a dataset to make predictions based not only on the features of instances (represented by nodes) but also on their relationships or connections with other instances.
