Chuyển đến nội dung chính

Giới thiệu toàn diện về Mạng nơ-ron Đồ thị (GNN)

Tìm hiểu mọi thứ về Mạng nơ-ron Đồ thị, bao gồm GNN là gì, các loại mạng nơ-ron đồ thị khác nhau và chúng được dùng để làm gì. Ngoài ra, học cách xây dựng một Mạng nơ-ron Đồ thị với Pytorch.
Đã cập nhật 5 thg 6, 2026  · 15 phút đọc

GNN

Đồ thị là gì?

Đồ thị là một kiểu cấu trúc dữ liệu gồm các nút (node) và cạnh (edge). Một nút có thể là người, địa điểm hoặc vật, và các cạnh xác định mối quan hệ giữa các nút. Cạnh có thể có hướng hoặc vô hướng dựa trên sự phụ thuộc theo hướng. 

Trong ví dụ bên dưới, các vòng tròn màu xanh là các nút và các mũi tên là các cạnh. Hướng của cạnh xác định sự phụ thuộc giữa hai nút. 

dependencies between two nodes

Hình minh họa bởi Tác giả 

Hãy cùng tìm hiểu bộ dữ liệu Đồ thị phức tạp: Mạng lưới Nhạc sĩ Jazz. Nó chứa 198 nút và 2742 cạnh. Trong biểu đồ đồ thị cộng đồng bên dưới, các màu nút khác nhau biểu thị các cộng đồng nhạc sĩ Jazz khác nhau và các cạnh kết nối giữa họ. Có một mạng lưới hợp tác nơi một nhạc sĩ có mối quan hệ trong và ngoài cộng đồng. 

Community Graph

 Biểu đồ Đồ thị Cộng đồng bởi Mạng lưới Nhạc sĩ Jazz

Đồ thị rất hiệu quả khi xử lý các bài toán phức tạp có mối quan hệ và tương tác. Chúng được dùng trong nhận dạng mẫu, phân tích mạng xã hội, hệ thống gợi ý và phân tích ngữ nghĩa. Xây dựng các giải pháp dựa trên đồ thị là một lĩnh vực mới, cung cấp nhiều góc nhìn sâu sắc về những bộ dữ liệu phức tạp và liên kết chặt chẽ.

Đồ thị với NetworkX

Trong phần này, chúng ta sẽ học cách tạo một đồ thị bằng NetworkX

Đoạn mã dưới đây chịu ảnh hưởng từ bài blog của Daniel Holmberg về Mạng nơ-ron Đồ thị trong Python.

  1. Tạo đối tượng DiGraph của networkx “H”
  2. Thêm các nút với nhãn, màu sắc và kích thước khác nhau
  3. Thêm các cạnh để tạo mối quan hệ giữa hai nút. Ví dụ, “(0,1)” nghĩa là 0 phụ thuộc có hướng vào 1. Chúng ta sẽ tạo quan hệ hai chiều bằng cách thêm “(1,0)”
  4. Trích xuất màu và kích thước dưới dạng danh sách
  5. Vẽ đồ thị bằng hàm draw của networkx
import networkx as nx
H = nx.DiGraph()

#adding nodes
H.add_nodes_from([
  (0, {"color": "blue", "size": 250}),

  (1, {"color": "yellow", "size": 400}),

  (2, {"color": "orange", "size": 150}),

  (3, {"color": "red", "size": 600})


])

#adding edges
H.add_edges_from([
  (0, 1),

  (1, 2),

  (1, 0),

  (1, 3),

  (2, 3),

  (3,0)


])

node_colors = nx.get_node_attributes(H, "color").values()
colors = list(node_colors)
node_sizes = nx.get_node_attributes(H, "size").values()
sizes = list(node_sizes)

#Plotting Graph
nx.draw(H, with_labels=True, node_color=colors, node_size=sizes)

undirectional graph

Ở bước tiếp theo, chúng ta sẽ chuyển cấu trúc dữ liệu từ đồ thị có hướng sang đồ thị vô hướng bằng hàm to_undirected()

#converting to undirected graph
G = H.to_undirected()
nx.draw(G, with_labels=True, node_color=colors, node_size=sizes)

Tại sao phân tích đồ thị lại khó?

Các cấu trúc dữ liệu dựa trên đồ thị có những hạn chế, và nhà khoa học dữ liệu cần hiểu rõ trước khi phát triển giải pháp dựa trên đồ thị.

  1. Đồ thị tồn tại trong không gian phi Euclid. Nó không tồn tại trong không gian 2D hoặc 3D, khiến việc diễn giải dữ liệu khó hơn. Để trực quan hóa cấu trúc trong không gian 2D, bạn phải dùng các công cụ giảm chiều khác nhau.
  2. Đồ thị có tính động; chúng không có dạng cố định. Có thể có hai đồ thị trông khác nhau nhưng lại có biểu diễn ma trận kề tương tự. Điều này khiến chúng ta khó phân tích dữ liệu bằng các công cụ thống kê truyền thống. 
  3. Kích thước lớn và số chiều cao sẽ làm tăng độ phức tạp của đồ thị đối với việc diễn giải của con người. Cấu trúc dày đặc với nhiều nút và hàng nghìn cạnh khó hiểu và khó rút ra insight. 

Mạng nơ-ron Đồ thị (GNN) là gì?

Mạng nơ-ron Đồ thị là các loại mạng nơ-ron đặc biệt có khả năng làm việc với cấu trúc dữ liệu dạng đồ thị. Chúng chịu ảnh hưởng lớn từ Mạng nơ-ron Tích chập (CNN) và kỹ thuật nhúng đồ thị. GNN được dùng để dự đoán nút, cạnh và các tác vụ dựa trên đồ thị. 

  • CNN được dùng cho phân loại ảnh. Tương tự, GNN được áp dụng lên cấu trúc đồ thị (lưới điểm ảnh) để dự đoán một lớp. 
  • Mạng nơ-ron Hồi quy được dùng trong phân loại văn bản. Tương tự, GNN được áp dụng lên cấu trúc đồ thị nơi mỗi từ là một nút trong câu.  

GNN ra đời khi Mạng nơ-ron Tích chập không đạt được kết quả tối ưu do kích thước tùy ý của đồ thị và cấu trúc phức tạp. 

Neural network

Hình ảnh bởi Purvanshi Mehta

Đồ thị đầu vào được đưa qua một chuỗi các mạng nơ-ron. Cấu trúc đồ thị đầu vào được chuyển thành nhúng đồ thị, cho phép chúng ta giữ lại thông tin về các nút, cạnh và ngữ cảnh tổng thể. 

Sau đó, véc-tơ đặc trưng của các nút AC được đưa qua lớp mạng nơ-ron. Nó tổng hợp các đặc trưng này và chuyển sang lớp tiếp theo - neptune.ai.

Đọc hướng dẫn Deep Learning của chúng tôi hoặc tham gia khóa học Giới thiệu về Học sâu để tìm hiểu thêm về các thuật toán và ứng dụng học sâu. 

Các loại Mạng nơ-ron Đồ thị

Có nhiều loại mạng nơ-ron, và phần lớn trong số đó là biến thể của Mạng nơ-ron Tích chập. Trong phần này, chúng ta sẽ tìm hiểu về những GNN phổ biến nhất. 

  • Mạng Tích chập Đồ thị (GCN) tương tự như CNN truyền thống. Nó học đặc trưng bằng cách quan sát các nút lân cận. GNN tổng hợp các véc-tơ nút, đưa kết quả vào lớp dày (dense), và áp dụng phi tuyến tính bằng hàm kích hoạt. Tóm lại, nó gồm tích chập đồ thị, lớp tuyến tính và hàm kích hoạt phi tuyến. Có hai loại GCN chính: Mạng Tích chập Không gian và Mạng Tích chập Phổ.
  • Mạng Tự mã hóa Đồ thị học biểu diễn đồ thị bằng bộ mã hóa (encoder) và cố gắng tái tạo đồ thị đầu vào bằng bộ giải mã (decoder). Encoder và decoder được nối với nhau qua một lớp nút cổ chai. Chúng thường được dùng trong dự đoán liên kết vì Auto-Encoder xử lý tốt mất cân bằng lớp. 
  • Mạng nơ-ron Đồ thị Hồi quy (RGNN) học mẫu khuếch tán tối ưu và có thể xử lý đồ thị đa quan hệ, nơi một nút có nhiều mối quan hệ. Loại mạng này dùng các bộ điều chuẩn (regularizer) để tăng độ mượt và loại bỏ hiện tượng quá tham số. RGNN dùng ít tài nguyên tính toán hơn để cho kết quả tốt hơn. Chúng được dùng trong sinh văn bản, dịch máy, nhận dạng giọng nói, tạo mô tả ảnh, gắn thẻ video và tóm tắt văn bản.
  • Mạng nơ-ron Đồ thị có Cổng (GGNN) tốt hơn RGNN trong các tác vụ có phụ thuộc dài hạn. GGNN cải tiến RGNN bằng cách thêm các cổng theo nút, cạnh và thời gian cho các phụ thuộc dài hạn. Tương tự như các Đơn vị Hồi quy có Cổng (GRU), các cổng được dùng để nhớ và quên thông tin ở các trạng thái khác nhau. 

Nếu bạn quan tâm tìm hiểu thêm về Mạng nơ-ron Hồi quy (RNN), hãy xem khóa học của DataCamp. Khóa học sẽ giới thiệu cho bạn các kiến trúc mô hình RNN khác nhau, khung Keras và ứng dụng của RNN. 

Các loại tác vụ của Mạng nơ-ron Đồ thị

Dưới đây, chúng tôi phác thảo một số loại tác vụ GNN kèm ví dụ:

  • Phân loại Đồ thị: dùng để phân loại các đồ thị vào nhiều nhóm. Ứng dụng gồm phân tích mạng xã hội và phân loại văn bản. 
  • Phân loại Nút: tác vụ này dùng nhãn của các nút lân cận để dự đoán các nhãn nút còn thiếu trong đồ thị. 
  • Dự đoán Liên kết: dự đoán liên kết giữa một cặp nút trong đồ thị với ma trận kề không đầy đủ. Thường được dùng cho mạng xã hội. 
  • Phát hiện Cộng đồng: chia các nút thành nhiều cụm khác nhau dựa trên cấu trúc cạnh. Nó học từ trọng số cạnh, khoảng cách và các đối tượng đồ thị tương tự. 
  • Nhúng Đồ thị: ánh xạ đồ thị thành các véc-tơ, bảo toàn thông tin liên quan về nút, cạnh và cấu trúc.
  • Sinh Đồ thị: học từ phân phối mẫu đồ thị để tạo ra cấu trúc đồ thị mới nhưng tương tự. 

Types of Graph Neural Networks

Hình minh họa bởi Tác giả

Nhược điểm của Mạng nơ-ron Đồ thị

Có một vài nhược điểm khi sử dụng GNN. Hiểu chúng sẽ giúp chúng ta quyết định khi nào nên dùng GNN và cách tối ưu hiệu năng mô hình học máy. 

  1. Hầu hết các mạng nơ-ron có thể đi sâu để đạt hiệu năng tốt hơn, trong khi GNN là mạng nông, thường chỉ có ba lớp. Điều này hạn chế chúng ta đạt hiệu năng tiên tiến trên các bộ dữ liệu lớn.
  2. Cấu trúc đồ thị thay đổi liên tục, khiến việc huấn luyện mô hình trở nên khó khăn. 
  3. Triển khai mô hình vào sản xuất gặp vấn đề mở rộng vì các mạng này tiêu tốn nhiều tài nguyên tính toán. Nếu bạn có một cấu trúc đồ thị lớn và phức tạp, việc mở rộng GNN trong môi trường sản xuất sẽ khó khăn. 

Mạng Tích chập Đồ thị (GCN) là gì?

Phần lớn GNN là Mạng Tích chập Đồ thị, và điều quan trọng là phải tìm hiểu về chúng trước khi bắt tay vào hướng dẫn phân loại nút.  

Tích chập trong GCN giống như tích chập trong các mạng nơ-ron tích chập. Nó nhân các nơ-ron với trọng số (bộ lọc) để học từ các đặc trưng dữ liệu. 

Nó hoạt động như các cửa sổ trượt trên toàn bộ ảnh để học đặc trưng từ các ô lân cận. Bộ lọc sử dụng chia sẻ trọng số để học các đặc điểm khuôn mặt khác nhau trong hệ thống nhận dạng ảnh - Towards Data Science

Bây giờ chuyển cùng chức năng đó sang Mạng Tích chập Đồ thị, nơi mô hình học đặc trưng từ các nút lân cận. Khác biệt lớn giữa GCN và CNN là GCN được phát triển để làm việc với cấu trúc dữ liệu phi Euclid, nơi thứ tự của các nút và cạnh có thể thay đổi. 

CNN vs GCN

CNN vs GCN | Nguồn

Tìm hiểu thêm về CNN cơ bản qua hướng dẫn Mạng nơ-ron Tích chập (CNN) với TensorFlow

Có hai loại GCN: 

  • Mạng Tích chập Đồ thị Không gian sử dụng các đặc trưng không gian để học từ các đồ thị nằm trong không gian không gian.  
  • Mạng Tích chập Đồ thị Phổ sử dụng phân rã riêng trị của ma trận Laplace đồ thị để lan truyền thông tin dọc theo các nút. Các mạng này được lấy cảm hứng từ sự lan truyền sóng trong tín hiệu và hệ thống. 

GNN hoạt động như thế nào? Xây dựng Mạng nơ-ron Đồ thị với Pytorch

Chúng ta sẽ xây dựng và huấn luyện Tích chập Đồ thị Phổ cho mô hình phân loại nút. Mã nguồn có sẵn trong workbook DataLab này để bạn trải nghiệm và chạy mô hình học máy dựa trên đồ thị đầu tiên của mình. 

Các ví dụ mã chịu ảnh hưởng từ tài liệu của Pytorch geometric

Bắt đầu

Chúng ta sẽ cài đặt gói Pytorch pytorch_geometric được xây dựng dựa trên nó. 

!pip install -q torch

Sau đó, chúng ta sẽ dùng phiên bản torch để cài đặt torch-scattertorch-sparse. Tiếp theo, chúng ta sẽ cài đặt bản phát hành mới nhất của pytorch_geometric từ GitHub. 

%%capture
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['PYTHONWARNINGS'] = "ignore"
!pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git

Bộ dữ liệu Planetoid Cora

Planetoid là bộ dữ liệu mạng trích dẫn từ Cora, CiteSeer và PubMed. Các nút là các tài liệu với véc-tơ đặc trưng túi từ (bag-of-words) 1433 chiều, và các cạnh là liên kết trích dẫn giữa các bài báo nghiên cứu. Có 7 lớp, và chúng ta sẽ huấn luyện mô hình để dự đoán các nhãn còn thiếu. 

Chúng ta sẽ nạp bộ dữ liệu Planetoid Cora và chuẩn hóa theo hàng các đặc trưng đầu vào dạng túi từ. Sau đó, chúng ta sẽ phân tích bộ dữ liệu và đối tượng đồ thị đầu tiên. 

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

dataset = Planetoid(root='data/Planetoid', name='Cora', transform=NormalizeFeatures())

print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.
print(data)

Bộ dữ liệu Cora có 2708 nút, 10.556 cạnh, 1433 đặc trưng7 lớp. Đối tượng đầu tiên có 2708 mặt nạ train, validation và test. Chúng ta sẽ dùng các mặt nạ này để huấn luyện và đánh giá mô hình. 

Dataset: Cora():
======================
Number of graphs: 1
Number of features: 1433
Number of classes: 7
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

Phân loại Nút với GNN

Chúng ta sẽ tạo cấu trúc mô hình GCN gồm hai lớp GCNConv, kích hoạt relu và tỷ lệ dropout 0,5. Mô hình có 16 kênh ẩn.  

Lớp GCN:

GCN layer

W(ℓ+1) là ma trận trọng số có thể huấn luyện trong phương trình trên và Cw,v biểu thị hệ số chuẩn hóa cố định cho mỗi cạnh.

from torch_geometric.nn import GCNConv
import torch.nn.functional as F

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(1234567)
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = GCN(hidden_channels=16)
print(model)

>>> GCN(
    (conv1): GCNConv(1433, 16)
    (conv2): GCNConv(16, 7)
  )

Trực quan hóa GCN chưa huấn luyện

Hãy trực quan hóa nhúng nút của các mạng GCN chưa huấn luyện bằng sklearn.manifold.TSNE và matplotlib.pyplot. Nó sẽ vẽ nhúng nút 7 chiều trên biểu đồ tán xạ 2D.  

%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def visualize(h, color):
    z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())

    plt.figure(figsize=(10,10))
    plt.xticks([])
    plt.yticks([])

    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.show()

Chúng ta sẽ đưa mô hình vào chế độ đánh giá rồi thêm dữ liệu huấn luyện vào mô hình chưa huấn luyện để trực quan hóa các nút và nhóm khác nhau. 

model.eval()

out = model(data.x, data.edge_index)
visualize(out, color=data.y)

untrained model

Huấn luyện GNN

Chúng ta sẽ huấn luyện mô hình trong 100 Epoch bằng tối ưu hóa Adam và hàm mất mát Cross-Entropy

Trong hàm train, chúng ta:

  1. Xóa gradient
  2. Thực hiện một lượt truyền xuôi
  3. Tính loss bằng các nút huấn luyện
  4. Tính gradient và cập nhật tham số

Trong hàm test, chúng ta:

  1. Dự đoán lớp của nút
  2. Trích xuất nhãn lớp có xác suất cao nhất
  3. Kiểm tra có bao nhiêu giá trị được dự đoán đúng
  4. Tạo tỷ lệ chính xác bằng tổng số dự đoán đúng chia cho tổng số nút. 
model = GCN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

def train():
      model.train()
      optimizer.zero_grad()
      out = model(data.x, data.edge_index)
      loss = criterion(out[data.train_mask], data.y[data.train_mask])
      loss.backward()
      optimizer.step()
      return loss

def test():
      model.eval()
      out = model(data.x, data.edge_index)
      pred = out.argmax(dim=1)
      test_correct = pred[data.test_mask] == data.y[data.test_mask]
      test_acc = int(test_correct.sum()) / int(data.test_mask.sum())
      return test_acc


for epoch in range(1, 101):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
GAT(
  (conv1): GATConv(1433, 8, heads=8)
  (conv2): GATConv(64, 7, heads=8)
)

.. .. .. ..
.. .. .. ..
Epoch: 098, Loss: 0.5989
Epoch: 099, Loss: 0.6021
Epoch: 100, Loss: 0.5799

Đánh giá mô hình

Bây giờ chúng ta sẽ đánh giá mô hình trên tập dữ liệu chưa thấy bằng hàm test, và như bạn thấy, chúng ta đạt kết quả khá tốt với độ chính xác 81,5%.  

test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')


>>> Test Accuracy: 0.8150

Giờ chúng ta sẽ trực quan hóa nhúng đầu ra của mô hình đã huấn luyện để kiểm chứng kết quả.

model.eval()
out = model(data.x, data.edge_index)
visualize(out, color=data.y)

Như có thể thấy, mô hình đã huấn luyện cho ra sự gom cụm các nút tốt hơn theo cùng hạng mục.

clustering of nodes

Huấn luyện mô hình GATConv

Ở phần thứ hai, chúng ta sẽ thay GCNConv bằng các lớp GATConv. Mạng Chú ý Đồ thị sử dụng các lớp tự chú ý có che mặt nạ để khắc phục nhược điểm của GCNConv và đạt kết quả tiên tiến. 

Bạn cũng có thể thử các lớp GNN khác và thử nghiệm với tối ưu hóa, dropout và số kênh ẩn để đạt hiệu năng tốt hơn. 

Trong đoạn mã dưới đây, chúng ta chỉ thay GCNConv bằng GATConv với 8 đầu chú ý ở lớp đầu tiên và 1 ở lớp thứ hai. 

Chúng ta cũng sẽ đặt:

  • tỷ lệ dropout là 0,6
  • kênh ẩn là 8
  • tốc độ học 0,005

Chúng ta đã sửa đổi hàm test để tìm độ chính xác của một mặt nạ cụ thể (validation, test). Điều này giúp in ra điểm validation và test trong quá trình huấn luyện. Chúng ta cũng lưu kết quả validation và test để vẽ biểu đồ đường sau đó. 

from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, heads):
        super().__init__()
        torch.manual_seed(1234567)
        self.conv1 = GATConv(dataset.num_features, hidden_channels,heads)
        self.conv2 = GATConv(heads*hidden_channels, dataset.num_classes,heads)

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = GAT(hidden_channels=8, heads=8)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

def train():
      model.train()
      optimizer.zero_grad()
      out = model(data.x, data.edge_index)
      loss = criterion(out[data.train_mask], data.y[data.train_mask])
      loss.backward()
      optimizer.step()
      return loss

def test(mask):
      model.eval()
      out = model(data.x, data.edge_index)
      pred = out.argmax(dim=1)
      correct = pred[mask] == data.y[mask]
      acc = int(correct.sum()) / int(mask.sum())
      return acc

val_acc_all = []
test_acc_all = []

for epoch in range(1, 101):
    loss = train()
    val_acc = test(data.val_mask)
    test_acc = test(data.test_mask)
    val_acc_all.append(val_acc)
    test_acc_all.append(test_acc)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

.. .. .. ..
.. .. .. ..
Epoch: 098, Loss: 1.1283, Val: 0.7960, Test: 0.8030

Epoch: 099, Loss: 1.1352, Val: 0.7940, Test: 0.8050

Epoch: 100, Loss: 1.1053, Val: 0.7960, Test: 0.8040

Như chúng ta quan sát, mô hình của chúng ta không hiệu quả hơn GCNConv. Cần tối ưu siêu tham số hoặc nhiều Epoch hơn để đạt kết quả tiên tiến. 

Đánh giá mô hình

Ở phần đánh giá, chúng ta trực quan hóa điểm số validation và testing bằng biểu đồ đường của matplotlib.pyplot.  

import numpy as np

plt.figure(figsize=(12,8))
plt.plot(np.arange(1, len(val_acc_all) + 1), val_acc_all, label='Validation accuracy', c='blue')
plt.plot(np.arange(1, len(test_acc_all) + 1), test_acc_all, label='Testing accuracy', c='red')
plt.xlabel('Epochs')
plt.ylabel('Accurarcy')
plt.title('GATConv')
plt.legend(loc='lower right', fontsize='x-large')
plt.savefig('gat_loss.png')
plt.show()

Sau 60 Epoch, độ chính xác validation và testing đã đạt giá trị ổn định khoảng 0,8 +/-0,02. 

GATConv model

Một lần nữa, hãy trực quan hóa sự gom cụm nút của mô hình GATConv.

model.eval()

out = model(data.x, data.edge_index)
visualize(out, color=data.y)

Như có thể thấy, lớp GATConv tạo ra kết quả tương tự về gom cụm trên cùng hạng mục các nút. 

clustering

Chúng ta có thể giảm overfitting bằng cách thêm tập dữ liệu validation thứ hai và cải thiện hiệu năng mô hình bằng cách thử nghiệm các lớp GCN khác nhau từ pytoch_geometric

Mã nguồn của hướng dẫn có trong workbook DataLab này. Hãy tạo một bản sao workbook để bạn có thể chạy.

Bổ sung kỹ năng Deep Learning vào Sơ yếu lý lịch của bạn bằng cách tham gia lộ trình kỹ năng Deep Learning bằng Python. Khóa học sẽ giới thiệu cho bạn các thuật toán học sâu, Keras, Pytorch và khung Tensorflow. 

Câu hỏi thường gặp

GNN được dùng để làm gì?

Mạng nơ-ron Đồ thị được áp dụng trực tiếp lên các bộ dữ liệu đồ thị và bạn có thể huấn luyện chúng để dự đoán các tác vụ liên quan đến nút, cạnh và đồ thị. Chúng được dùng cho phân loại đồ thị và nút, dự đoán liên kết, gom cụm và sinh đồ thị, cũng như phân loại ảnh và văn bản.

Đồ thị trong Mạng nơ-ron Đồ thị là gì?

Đồ thị là một cấu trúc dữ liệu gồm các nút và các kết nối giữa chúng được gọi là cạnh. Cạnh có thể có hướng hoặc vô hướng. Nó có hình dạng động và cấu trúc đa chiều. Ví dụ, trên mạng xã hội, các nút là những người trong nhóm bạn bè của bạn, và các cạnh là mối quan hệ giữa bạn và mọi người khác. 

GNN mạnh đến mức nào?

Mạng nơ-ron Đồ thị vượt trội hơn các Mạng nơ-ron Tích chập (CNN) điển hình trong phân loại ảnh và nút. Nhiều biến thể GNN đã đạt kết quả tiên tiến ở cả tác vụ phân loại nút và đồ thị - openreview.net.  

Mạng nơ-ron có dùng lý thuyết đồ thị không?

Có, Mạng nơ-ron có liên quan chặt chẽ đến lý thuyết đồ thị và được thiết kế để làm việc trên dữ liệu phi Euclid. Một số trong đó bản thân là đồ thị hoặc xuất ra đồ thị. 

Mạng Tích chập Đồ thị là gì?

Mạng Tích chập Đồ thị tương tự như Mạng nơ-ron Tích chập hoạt động với các bộ dữ liệu đồ thị. Nó bao gồm tích chập đồ thị, lớp tuyến tính và kích hoạt phi tuyến. GNN truyền các bộ lọc qua đồ thị, quan sát các nút và cạnh có thể dùng để phân loại các nút trong dữ liệu.

Đồ thị trong Học sâu là gì?

Học sâu Đồ thị còn được gọi là Học sâu Hình học. Nó sử dụng nhiều lớp mạng nơ-ron để đạt hiệu năng tốt hơn. Đây là một lĩnh vực nghiên cứu sôi động, nơi các nhà khoa học đang cố gắng tăng số lớp mà không làm giảm hiệu năng. 

Chủ đề

Khóa học Python

Courses

Python nâng cao

4 giờ
1.4M
Nâng cao kỹ năng Khoa học dữ liệu của bạn bằng cách tạo trực quan hóa với Matplotlib và thao tác DataFrame bằng pandas.
Xem chi tiếtRight Arrow
Bắt đầu khóa học
Xem thêmRight Arrow