Direkt zum Inhalt

Eine umfassende Einführung in Graph Neural Networks (GNNs)

Erfahre alles über Graph Neural Networks (GNNs), was sie sind, die verschiedenen Arten von Graph Neural Networks und wofür sie verwendet werden. Außerdem erfährst du, wie du mit Pytorch ein Graph Neural Network erstellst.
Aktualisierte 16. Jan. 2025  · 15 Min. Lesezeit

GNN

Was ist eine Grafik?

Ein Graph ist eine Art Datenstruktur, die Knoten und Kanten enthält. Ein Knoten kann eine Person, ein Ort oder eine Sache sein, und die Kanten definieren die Beziehung zwischen den Knoten. Die Kanten können aufgrund von Richtungsabhängigkeiten gerichtet und ungerichtet sein. 

In dem Beispiel unten sind die blauen Kreise Knoten und die Pfeile sind Kanten. Die Richtung der Kanten definiert Abhängigkeiten zwischen zwei Knotenpunkten. 

Abhängigkeiten zwischen zwei Knotenpunkten

Bild vom Autor 

Lernen wir den komplexen Graph-Datensatz kennen: Jazz Musicians Network. Sie enthält 198 Knoten und 2742 Kanten. Im folgenden Community-Diagramm stehen die verschiedenfarbigen Knoten für verschiedene Gemeinschaften von Jazzmusikern und die Kanten, die sie verbinden. Es gibt ein Netz der Zusammenarbeit, in dem ein einzelner Musiker Beziehungen innerhalb und außerhalb der Gemeinschaft hat. 

Community Graph

 Community Graph Plot von Jazz Musicians Network

Diagramme eignen sich hervorragend, um komplexe Probleme mit Beziehungen und Wechselwirkungen zu behandeln. Sie werden in der Mustererkennung, der Analyse sozialer Netzwerke, in Empfehlungssystemen und der semantischen Analyse eingesetzt. Das Erstellen von graphenbasierten Lösungen ist ein ganz neues Feld, das reiche Einblicke in komplexe und vernetzte Datensätze bietet.

Diagramme mit NetworkX

In diesem Abschnitt lernen wir, wie man mit NetworkX einen Graphen erstellt. 

Der folgende Code wurde von Daniel Holmbergs Blog über Graph Neural Networks in Python beeinflusst.

  1. Erstelle das DiGraph-Objekt "H" von networkx
  2. Knoten hinzufügen, die unterschiedliche Beschriftungen, Farben und Größen enthalten
  3. Füge Kanten hinzu, um eine Beziehung zwischen zwei Knotenpunkten herzustellen. Zum Beispiel bedeutet "(0,1)", dass 0 eine Richtungsabhängigkeit von 1 hat. Wir werden bidirektionale Beziehungen erstellen, indem wir "(1,0)" hinzufügen.
  4. Farben und Größen in Form von Listen extrahieren
  5. Zeichne das Diagramm mit der Zeichenfunktion von networkx
import networkx as nx
H = nx.DiGraph()

#adding nodes
H.add_nodes_from([
  (0, {"color": "blue", "size": 250}),

  (1, {"color": "yellow", "size": 400}),

  (2, {"color": "orange", "size": 150}),

  (3, {"color": "red", "size": 600})


])

#adding edges
H.add_edges_from([
  (0, 1),

  (1, 2),

  (1, 0),

  (1, 3),

  (2, 3),

  (3,0)


])

node_colors = nx.get_node_attributes(H, "color").values()
colors = list(node_colors)
node_sizes = nx.get_node_attributes(H, "size").values()
sizes = list(node_sizes)

#Plotting Graph
nx.draw(H, with_labels=True, node_color=colors, node_size=sizes)

ungerichteter Graph

Im nächsten Schritt werden wir die Datenstruktur mit der Funktion to_undirected() von einem gerichteten in einen ungerichteten Graphen umwandeln. 

#converting to undirected graph
G = H.to_undirected()
nx.draw(G, with_labels=True, node_color=colors, node_size=sizes)

Warum ist es schwierig, eine Grafik zu analysieren?

Graphenbasierte Datenstrukturen haben Nachteile, und Datenwissenschaftler müssen sie verstehen, bevor sie graphenbasierte Lösungen entwickeln.

  1. Ein Graph existiert im nicht-euklidischen Raum. Sie existiert nicht im 2D- oder 3D-Raum, was die Interpretation der Daten erschwert. Um die Struktur im 2D-Raum zu visualisieren, musst du verschiedene Werkzeuge zur Dimensionalitätsreduktion verwenden.
  2. Diagramme sind dynamisch; sie haben keine feste Form. Es kann zwei visuell unterschiedliche Graphen geben, die aber eine ähnliche Adjazenzmatrix haben können. Das macht es uns schwer, die Daten mit traditionellen statistischen Instrumenten zu analysieren. 
  3. Eine große Größe und Dimensionalität erhöht die Komplexität des Graphen für menschliche Interpretationen. Die dichte Struktur mit vielen Knoten und Tausenden von Kanten ist schwieriger zu verstehen und Erkenntnisse zu gewinnen. 

Was ist ein Graph Neural Network (GNN)?

Graphenneuronale Netze sind spezielle Arten von neuronalen Netzen, die mit einer Graphen-Datenstruktur arbeiten können. Sie werden stark von Convolutional Neural Networks (CNNs) und Graph Embedding beeinflusst. GNNs werden bei der Vorhersage von Knoten, Kanten und graphenbasierten Aufgaben eingesetzt. 

  • CNNs werden zur Bildklassifizierung eingesetzt. In ähnlicher Weise werden GNNs auf eine Graphenstruktur (Gitter aus Pixeln) angewendet, um eine Klasse vorherzusagen. 
  • Neuronale Netze mit Rekursion werden bei der Textklassifizierung eingesetzt. Ähnlich werden GNNs auf Graphenstrukturen angewandt, bei denen jedes Wort ein Knoten in einem Satz ist.  

GNNs wurden eingeführt, als Convolutional Neural Networks aufgrund der willkürlichen Größe des Graphen und der komplexen Struktur keine optimalen Ergebnisse erzielen konnten. 

Neuronales Netz

Bild von Purvanshi Mehta

Der Eingangsgraph wird durch eine Reihe von neuronalen Netzen geleitet. Die eingegebene Graphenstruktur wird in eine Grapheneinbettung umgewandelt, die es uns ermöglicht, Informationen über Knoten, Kanten und den globalen Kontext zu erhalten. 

Dann wird der Merkmalsvektor der Knoten A und C durch die Schicht des neuronalen Netzes geleitet. Sie fasst diese Merkmale zusammen und gibt sie an die nächste Ebene weiter - neptune.ai.

Lies unser Deep Learning-Tutorial oder besuche unseren Kurs Einführung in Deep Learning, um mehr über Deep Learning-Algorithmen und -Anwendungen zu erfahren. 

Arten von Graph Neural Networks

Es gibt verschiedene Arten von neuronalen Netzen, und die meisten von ihnen haben eine Variante von Convolutional Neural Networks. In diesem Abschnitt werden wir uns mit den beliebtesten GNNs beschäftigen. 

  • Graph Convolutional Networks (GCNs ) sind den traditionellen CNNs ähnlich. Es lernt Merkmale, indem es die benachbarten Knotenpunkte untersucht. GNNs aggregieren Knotenvektoren, leiten das Ergebnis an die dichte Schicht weiter und wenden die Nichtlinearität mithilfe der Aktivierungsfunktion an. Kurz gesagt, besteht sie aus einer Graph-Faltung, einer linearen Schicht und einer Aktivierungsfunktion ohne Lernfunktion. Es gibt zwei Haupttypen von GCNs: Spatial Convolutional Networks und Spectral Convolutional Networks.
  • Graphen-Auto-Encoder Netzwerke lernen die Graphenrepräsentation mithilfe eines Encoders und versuchen, die Eingabegraphen mithilfe eines Decoders zu rekonstruieren. Der Encoder und der Decoder sind durch eine Engpassschicht verbunden. Sie werden häufig in der Linkvorhersage verwendet, da Auto-Encoder gut mit Klassengleichheit umgehen können. 
  • Recurrent Graph Neural Networks (RGNNs) lernen das beste Diffusionsmuster und können mit multirelationalen Graphen umgehen, bei denen ein einzelner Knoten mehrere Beziehungen hat. Diese Art von neuronalen Graphen-Netzwerken verwendet Regularizer, um die Glätte zu erhöhen und eine Überparametrisierung zu vermeiden. RGNNs brauchen weniger Rechenleistung, um bessere Ergebnisse zu erzielen. Sie werden bei der Texterstellung, der maschinellen Übersetzung, der Spracherkennung, der Erstellung von Bildbeschreibungen, dem Tagging von Videos und der Textzusammenfassung eingesetzt.
  • Gated Graph Neural Networks (GGNNs ) sind besser als RGNNs, wenn es um Aufgaben mit langfristigen Abhängigkeiten geht. Gated Graph Neural Networks verbessern rekurrente neuronale Netze, indem sie Knoten-, Kanten- und Zeitgatter für langfristige Abhängigkeiten hinzufügen. Ähnlich wie bei Gated Recurrent Units (GRUs) werden die Gates verwendet, um Informationen in verschiedenen Zuständen zu erinnern und zu vergessen. 

Wenn du mehr über rekurrente neuronale Netze (RNNs) erfahren möchtest, schau dir den Kurs von DataCamp an. Es wird dich in verschiedene RNN-Modellarchitekturen, Keras-Frameworks und RNN-Anwendungen einführen. 

Arten von Graph Neural Networks Aufgaben

Im Folgenden haben wir einige Arten von GNN-Aufgaben mit Beispielen beschrieben:

  • Klassifizierung von Graphen: Wir verwenden dies, um Graphen in verschiedene Kategorien zu klassifizieren. Seine Anwendungen sind die Analyse sozialer Netzwerke und die Klassifizierung von Texten. 
  • Knotenklassifizierung: Diese Aufgabe nutzt die Kennzeichnungen benachbarter Knoten, um fehlende Knotenkennzeichnungen in einem Graphen vorherzusagen. 
  • Link Prediction: sagt die Verbindung zwischen einem Knotenpaar in einem Graphen mit einer unvollständigen Adjazenzmatrix voraus. Es wird häufig für soziale Netzwerke verwendet. 
  • Community Detection: teilt die Knoten anhand der Kantenstruktur in verschiedene Cluster ein. Es lernt von Kantengewichten, Entfernungen und Graphenobjekten auf ähnliche Weise. 
  • Graph Embedding: bildet Graphen in Vektoren ab, wobei die relevanten Informationen über Knoten, Kanten und Struktur erhalten bleiben.
  • Graphengenerierung: Lernt aus der Verteilung von Beispielgraphen, um eine neue, aber ähnliche Graphenstruktur zu erzeugen. 

Arten von Graph Neural Networks

Bild vom Autor

Nachteile von Graph Neural Networks

Die Verwendung von GNNs hat ein paar Nachteile. Wenn wir sie verstehen, können wir entscheiden, wann wir GNNa einsetzen und wie wir die Leistung unserer maschinellen Lernmodelle optimieren können. 

  1. Die meisten neuronalen Netze können tief gehen, um eine bessere Leistung zu erzielen, während GNNs flache Netze mit meist drei Schichten sind. Dadurch können wir bei großen Datensätzen keine Spitzenleistung erzielen.
  2. Die Graphenstrukturen ändern sich ständig, was es schwieriger macht, ein Modell darauf zu trainieren. 
  3. Beim Einsatz des Modells in der Produktion gibt es Probleme mit der Skalierbarkeit, da diese Netzwerke sehr rechenintensiv sind. Wenn du eine große und komplexe Graphenstruktur hast, wird es für dich schwierig sein, die GNNs in der Produktion zu skalieren. 

Was ist ein Graph Convolutional Network (GCN)?

Die meisten GNNs sind Graph Convolutional Networks, und es ist wichtig, etwas über sie zu erfahren, bevor du dich in ein Tutorial zur Knotenklassifizierung stürzt.  

Die Faltung in GCN ist dasselbe wie die Faltung in faltigen neuronalen Netzen. Es multipliziert Neuronen mit Gewichten (Filtern), um aus Datenmerkmalen zu lernen. 

Sie fungiert als gleitende Fenster auf ganzen Bildern, um Merkmale aus benachbarten Zellen zu lernen. Der Filter nutzt Weight Sharing, um verschiedene Gesichtsmerkmale in Bilderkennungssystemen zu lernen - Towards Data Science

Übertrage nun die gleiche Funktionalität auf Graph Convolutional Networks, bei denen ein Modell die Merkmale von benachbarten Knoten lernt. Der Hauptunterschied zwischen GCN und CNN besteht darin, dass es für nicht-euklidische Datenstrukturen entwickelt wurde, bei denen die Reihenfolge der Knoten und Kanten variieren kann. 

CNN vs GCN

CNN vs. GCN | Bildquelle

Erfahre mehr über grundlegende CNNs, indem du dem TensorFlow-Tutorial zu Convolutional Neural Networks (CNN) folgst. 

Es gibt zwei Arten von GCNs: 

  • Spatial Graph Convolutional Networks nutzen räumliche Merkmale, um aus Graphen zu lernen, die sich im Raum befinden.  
  • Spektrale Graphenfaltungsnetze nutzen die Eigenwertzerlegung der Laplacian-Matrix des Graphen für die Informationsausbreitung entlang der Knotenpunkte. Diese Netzwerke wurden von der Wellenausbreitung in Signalen und Systemen inspiriert. 

Wie funktionieren die GNNs? Aufbau eines neuronalen Graphen-Netzwerks mit Pytorch

Wir werden die Spektrale Graphenfaltung für ein Knotenklassifizierungsmodell aufbauen und trainieren. Der Quellcode ist in dieser DataLab-Arbeitsmappe verfügbar, damit du dein erstes graphbasiertes maschinelles Lernmodell ausprobieren und ausführen kannst. 

Die Kodierungsbeispiele sind von der geometrischen Dokumentation von Pytorch beeinflusst. 

Erste Schritte

Wir werden das Pytorch-Paket installieren, da pytorch_geometric darauf aufbaut. 

!pip install -q torch

Dann werden wir die torch-Version verwenden, um torch-scatter und torch-sparse zu installieren. Danach installieren wir die neueste Version von pytorch_geometricvon GitHub. 

%%capture
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['PYTHONWARNINGS'] = "ignore"
!pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git

Planetoid Cora Datensatz

Planetoid ist ein Zitationsnetzwerk-Datensatz aus Cora, CiteSeer und PubMed. Die Knoten sind Dokumente mit 1433-dimensionalen Bag-of-Words-Merkmalsvektoren, und die Kanten sind Zitierlinks zwischen Forschungsarbeiten. Es gibt 7 Klassen und wir werden das Modell trainieren, um fehlende Bezeichnungen vorherzusagen. 

Wir nehmen den Planetoid Cora-Datensatz auf und normalisieren die Bag of Words-Eingangsmerkmale in Zeilen. Danach werden wir den Datensatz und das erste Diagrammobjekt analysieren. 

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

dataset = Planetoid(root='data/Planetoid', name='Cora', transform=NormalizeFeatures())

print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.
print(data)

Der Cora-Datensatz hat 2708 Knoten, 10.556 Kanten, 1433 Merkmale und 7 Klassen. Das erste Objekt hat 2708 Trainings-, Validierungs- und Testmasken. Wir werden diese Masken verwenden, um das Modell zu trainieren und zu bewerten. 

Dataset: Cora():
======================
Number of graphs: 1
Number of features: 1433
Number of classes: 7
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

Knotenklassifizierung mit GNN

Wir erstellen eine GCN-Modellstruktur, die zwei GCNConv-Schichten mit einer relu Aktivierung und einer Dropout-Rate von 0,5 enthält. Das Modell besteht aus 16 versteckten Kanälen.  

GCN-Schicht:

GCN-Schicht

W(ℓ+1) ist in der obigen Gleichung eine übertragbare Gewichtsmatrix und Cw,v gibt einen festen Normalisierungskoeffizienten für jede Kante vor.

from torch_geometric.nn import GCNConv
import torch.nn.functional as F

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(1234567)
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = GCN(hidden_channels=16)
print(model)

>>> GCN(
    (conv1): GCNConv(1433, 16)
    (conv2): GCNConv(16, 7)
  )

Visualisierung des untrainierten GCN-Netzwerks

Lass uns die Knoteneinbettungen von untrainierten GCN-Netzwerken mit sklearn.manifold.TSNE und matplotlib.pyplot visualisieren. Es wird ein 7-dimensionaler Knoten gezeichnet, der ein 2D-Streudiagramm einbettet.  

%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def visualize(h, color):
    z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())

    plt.figure(figsize=(10,10))
    plt.xticks([])
    plt.yticks([])

    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.show()

Wir werten das Modell aus und fügen dem untrainierten Modell Trainingsdaten hinzu, um verschiedene Knoten und Kategorien zu visualisieren. 

model.eval()

out = model(data.x, data.edge_index)
visualize(out, color=data.y)

untrainiertes Modell

Ausbildung GNN

Wir werden unser Modell anhand von 100 Epochen mit der Adam-Optimierung und der Cross-Entropy Loss-Funktion trainieren. 

In der Zugfunktion haben wir:

  1. Lösche den Farbverlauf
  2. Einen einzelnen Vorwärtspass ausgeführt
  3. Berechne den Verlust mit Hilfe von Trainingsknoten
  4. Berechne den Gradienten und aktualisiere die Parameter

In der Testfunktion haben wir:

  1. Vorausgesagte Knotenklasse
  2. Extrahierte Klassenbezeichnung mit der höchsten Wahrscheinlichkeit
  3. Überprüft, wie viele Werte richtig vorhergesagt wurden
  4. Erstellen der Genauigkeitsquote anhand der Summe der richtigen Vorhersagen geteilt durch die Gesamtzahl der Knotenpunkte. 
model = GCN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

def train():
      model.train()
      optimizer.zero_grad()
      out = model(data.x, data.edge_index)
      loss = criterion(out[data.train_mask], data.y[data.train_mask])
      loss.backward()
      optimizer.step()
      return loss

def test():
      model.eval()
      out = model(data.x, data.edge_index)
      pred = out.argmax(dim=1)
      test_correct = pred[data.test_mask] == data.y[data.test_mask]
      test_acc = int(test_correct.sum()) / int(data.test_mask.sum())
      return test_acc


for epoch in range(1, 101):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
GAT(
  (conv1): GATConv(1433, 8, heads=8)
  (conv2): GATConv(64, 7, heads=8)
)

.. .. .. ..
.. .. .. ..
Epoch: 098, Loss: 0.5989
Epoch: 099, Loss: 0.6021
Epoch: 100, Loss: 0.5799

Modellbewertung

Wir werden das Modell nun an einem ungesehenen Datensatz mit der Testfunktion testen. Wie du siehst, haben wir mit 81,5 % Genauigkeit ziemlich gute Ergebnisse erzielt.  

test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')


>>> Test Accuracy: 0.8150

Um die Ergebnisse zu überprüfen, werden wir nun das Output Embedding eines trainierten Modells visualisieren.

model.eval()
out = model(data.x, data.edge_index)
visualize(out, color=data.y)

Wie wir sehen können, hat das trainierte Modell eine bessere Clusterung der Knoten für dieselbe Kategorie ergeben.

Clustering von Knotenpunkten

GATConv-Modell trainieren

Im zweiten Schritt werden wir GCNConv durch GATConv-Schichten ersetzen. Das Graph Attention Networks verwendet maskierte Selbstbeobachtungsschichten, um die Nachteile von GCNConv zu beheben und modernste Ergebnisse zu erzielen. 

Du kannst auch andere GNN-Schichten ausprobieren und mit Optimierungen, Aussetzern und der Anzahl versteckter Kanäle herumspielen, um eine bessere Leistung zu erzielen. 

Im folgenden Code haben wir gerade GCNConv durch GATConv mit 8 Aufmerksamkeitsköpfen in der ersten Schicht und 1 in der zweiten Schicht ersetzt. 

Wir werden auch einstellen:

  • Abbrecherquote auf 0,6
  • versteckte Kanäle auf 8
  • Lernrate 0,005

Wir haben die Testfunktion geändert, um die Genauigkeit einer bestimmten Maske (valid, test) zu ermitteln. Sie hilft uns, die Validierungs- und Testergebnisse während des Modelltrainings auszudrucken. Außerdem speichern wir die Validierungs- und Testergebnisse später in einem Liniendiagramm. 

from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, heads):
        super().__init__()
        torch.manual_seed(1234567)
        self.conv1 = GATConv(dataset.num_features, hidden_channels,heads)
        self.conv2 = GATConv(heads*hidden_channels, dataset.num_classes,heads)

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = GAT(hidden_channels=8, heads=8)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

def train():
      model.train()
      optimizer.zero_grad()
      out = model(data.x, data.edge_index)
      loss = criterion(out[data.train_mask], data.y[data.train_mask])
      loss.backward()
      optimizer.step()
      return loss

def test(mask):
      model.eval()
      out = model(data.x, data.edge_index)
      pred = out.argmax(dim=1)
      correct = pred[mask] == data.y[mask]
      acc = int(correct.sum()) / int(mask.sum())
      return acc

val_acc_all = []
test_acc_all = []

for epoch in range(1, 101):
    loss = train()
    val_acc = test(data.val_mask)
    test_acc = test(data.test_mask)
    val_acc_all.append(val_acc)
    test_acc_all.append(test_acc)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

.. .. .. ..
.. .. .. ..
Epoch: 098, Loss: 1.1283, Val: 0.7960, Test: 0.8030

Epoch: 099, Loss: 1.1352, Val: 0.7940, Test: 0.8050

Epoch: 100, Loss: 1.1053, Val: 0.7960, Test: 0.8040

Wie wir feststellen können, hat unser Modell nicht besser abgeschnitten als GCNConv. Es erfordert eine Optimierung der Hyperparameter oder mehr Epochen, um die besten Ergebnisse zu erzielen. 

Modellbewertung

Im Auswertungsteil visualisieren wir die Validierungs- und Testergebnisse mit dem Liniendiagramm von matplotlib.pyplot .  

import numpy as np

plt.figure(figsize=(12,8))
plt.plot(np.arange(1, len(val_acc_all) + 1), val_acc_all, label='Validation accuracy', c='blue')
plt.plot(np.arange(1, len(test_acc_all) + 1), test_acc_all, label='Testing accuracy', c='red')
plt.xlabel('Epochs')
plt.ylabel('Accurarcy')
plt.title('GATConv')
plt.legend(loc='lower right', fontsize='x-large')
plt.savefig('gat_loss.png')
plt.show()

Nach 60 Epochen hat die Validierungs- und Prüfgenauigkeit einen stabilen Wert von 0,8+/-0,02 erreicht. 

GATConv-Modell

Lass uns noch einmal das Knoten-Cluster des GATConv-Modells veranschaulichen.

model.eval()

out = model(data.x, data.edge_index)
visualize(out, color=data.y)

Wie wir sehen können, hat die GATConv-Schicht die gleichen Ergebnisse bei der Clusterbildung für dieselbe Kategorie von Knotenpunkten erzielt. 

clustering

Wir können die Überanpassung reduzieren, indem wir einen zweiten Validierungsdatensatz hinzufügen und die Modellleistung verbessern, indem wir mit verschiedenen GCN-Schichten aus pytoch_geometric experimentieren. 

Der Quellcode des Tutorials ist in dieser DataLab-Arbeitsmappe verfügbar. Erstelle eine Kopie der Arbeitsmappe, die du ausführen kannst.

Füge deinem Lebenslauf Deep Learning-Fähigkeiten hinzu, indem du den Lernpfad Deep Learning in Python belegst. Er führt dich in Deep-Learning-Algorithmen, Keras, Pytorch und das Tensorflow-Framework ein. 

FAQs

Wofür werden Graph Neural Networks verwendet?

Graphneuronale Netze werden direkt auf Graphdatensätze angewandt und du kannst sie trainieren, um Knoten, Kanten und graphenbezogene Aufgaben vorherzusagen. Es wird für die Klassifizierung von Graphen und Knoten, Link-Vorhersagen, das Clustering und die Generierung von Graphen sowie die Klassifizierung von Bildern und Texten verwendet.

Was ist ein Graph in einem neuronalen Graphennetz?

Ein Graph ist eine Datenstruktur, die aus Knoten besteht, und die Verbindungen zwischen den Knoten werden Kanten genannt. Die Kanten können gerichtet und ungerichtet sein. Sie hat dynamische Formen und multidimensionale Strukturen. In den sozialen Medien zum Beispiel sind die Knoten die Menschen in deiner Freundesgruppe und die Kanten die Beziehungen zwischen dir und den anderen. 

Wie leistungsfähig sind Graph Neural Networks?

Graph Neural Networks übertreffen typische Convolutional Neural Networks (CNN) bei der Klassifizierung von Bildern und Knoten. Viele GNN-Varianten haben sowohl bei der Klassifizierung von Knoten als auch von Graphen Spitzenergebnisse erzielt - openreview.net.

Nutzen Neuronale Netze die Graphentheorie?

Ja, Neuronale Netze sind eng mit der Graphentheorie verwandt, die auf nicht-euklidische Daten ausgelegt ist. Einige von ihnen sind selbst Graphen oder geben den Graphen aus. 

Was sind Graph Convolutional Networks?

Graph Convolutional Networks sind ähnlich wie Convolutional Neural Networks, die mit Graph-Datensätzen arbeiten. Sie besteht aus einer Graph-Faltung, einer linearen Schicht und einer nicht-linearen Aktivierung. GNNs lassen Filter über den Graphen laufen und untersuchen Knoten und Kanten, die zur Klassifizierung von Knoten innerhalb der Daten verwendet werden können.

Was ist ein Graph beim Deep Learning?

Graph Deep Learning ist auch als geometrisches Deep Learning bekannt. Es verwendet mehrere neuronale Netzwerkschichten, um eine bessere Leistung zu erzielen. Es ist ein aktiver Forschungsbereich, in dem die Wissenschaftler versuchen, die Anzahl der Schichten zu erhöhen, ohne die Leistung zu beeinträchtigen. 

Themen

Python-Kurse

Zertifizierung verfügbar

Kurs

Intermediate Python

4 hr
1.1M
Verbessere deine Data Science-Fähigkeiten, indem du mit Matplotlib Visualisierungen erstellst und DataFrames mit Pandas manipulierst.
Siehe DetailsRight Arrow
Kurs starten
Mehr anzeigenRight Arrow