Chuyển đến nội dung chính

Hướng dẫn PyTorch CNN: Xây dựng và Huấn luyện Mạng Nơ-ron Tích chập trong Python

Tìm hiểu cách xây dựng và triển khai Mạng Nơ-ron Tích chập (CNN) trong Python với PyTorch.
Đã cập nhật 5 thg 6, 2026  · 13 phút đọc

Mạng Nơ-ron Tích chập (CNN) là nền tảng của thị giác máy tính hiện đại, cho phép các ứng dụng như nhận dạng hình ảnh, phát hiện khuôn mặt và xe tự lái. Các mạng này được thiết kế để tự động trích xuất mẫu và đặc trưng từ hình ảnh, khiến chúng mạnh mẽ hơn các kỹ thuật học máy truyền thống cho các tác vụ thị giác.

Trong hướng dẫn này, chúng ta sẽ triển khai một CNN bằng PyTorch, một framework học sâu thân thiện với người dùng và hiệu quả cao cho cả nghiên cứu lẫn ứng dụng sản xuất.

Yêu cầu tiên quyết: Học sâu và PyTorch

Trước khi đi vào chi tiết về CNN, bạn cần quen thuộc với lĩnh vực học sâu và các thư viện Python mà chúng ta sẽ sử dụng trong quá trình thiết lập môi trường.

Học sâu là một phân nhánh của học máy, trong đó cấu trúc mô hình cơ bản là một mạng gồm đầu vào, các lớp ẩn và đầu ra. Mạng như vậy có thể có một hoặc nhiều lớp ẩn. Trực giác ban đầu đằng sau học sâu là tạo ra các mô hình lấy cảm hứng từ cách não người học: thông qua các tế bào liên kết gọi là nơ-ron. Đây là lý do vì sao chúng ta vẫn gọi các mô hình học sâu là mạng "nơ-ron". Các cấu trúc mô hình theo lớp này cần nhiều dữ liệu hơn hẳn các mô hình học có giám sát khác để học ra mẫu từ dữ liệu phi cấu trúc. Thông thường, chúng ta nói đến ít nhất hàng trăm nghìn điểm dữ liệu.

Mặc dù có nhiều framework và gói để triển khai thuật toán học sâu, chúng ta sẽ tập trung vào PyTorch, một trong những framework phổ biến và được duy trì tốt nhất. Bên cạnh việc được các kỹ sư học sâu trong công nghiệp sử dụng, PyTorch còn là công cụ ưa thích của giới nghiên cứu. Nhiều bài báo học sâu được công bố sử dụng PyTorch. Nó được thiết kế trực quan, thân thiện với người dùng và có nhiều điểm chung với thư viện NumPy của Python. 

Nếu bạn cần kiến thức nhập môn về các khái niệm này, hãy cân nhắc đăng ký khóa học Deep Learning with PyTorch ngay hôm nay. 

Convolutional Neural Network (CNN) là gì?

Mạng nơ-ron tích chập, thường gọi là CNN hoặc ConvNet, là một loại mạng nơ-ron sâu chuyên biệt, rất phù hợp cho các tác vụ thị giác máy tính. CNN được phát minh từ những năm 1980. Tuy nhiên, chúng chỉ trở nên phổ biến rộng rãi trong thập niên 2010, nhờ các đột phá tính toán từ việc triển khai bộ xử lý đồ họa (GPU). Thật vậy, sự phổ biến nhanh chóng của CNN đã giúp lĩnh vực mạng nơ-ron lấy lại vị thế, dẫn đến cái gọi là "làn sóng thứ ba của mạng nơ-ron" mà chúng ta vẫn đang chứng kiến ngày nay.

CNN được lấy cảm hứng trực tiếp từ vỏ não thị giác sinh học. Vỏ não có các vùng nhỏ gồm các tế bào nhạy với những khu vực cụ thể trong trường thị giác. Ý tưởng này đã được mở rộng bởi một thí nghiệm hấp dẫn của Hubel và Wiesel vào năm 1962. 

CNN cố gắng tái hiện đặc điểm này bằng cách tạo ra các mạng nơ-ron phức tạp gồm nhiều lớp chuyên biệt theo nhiệm vụ. CNN được gọi là "truyền thẳng" vì thông tin chảy xuyên qua mô hình. Không có các kết nối phản hồi trong đó đầu ra của mô hình được đưa ngược lại vào chính nó, so với những mô hình khác sử dụng các kỹ thuật như lan truyền ngược.

Cụ thể, một CNN thường bao gồm các lớp sau:

Lớp tích chập

Đây là khối xây dựng đầu tiên của CNN. Như tên gọi, tác vụ toán học chính được thực hiện là tích chập, tức áp dụng một hàm cửa sổ trượt lên ma trận điểm ảnh biểu diễn hình ảnh. Hàm trượt áp lên ma trận được gọi là kernel hay bộ lọc. Ở lớp tích chập, nhiều bộ lọc có cùng kích thước được áp dụng, và mỗi bộ lọc được dùng để nhận diện một mẫu cụ thể từ hình ảnh, chẳng hạn độ uốn của chữ số, các cạnh, toàn bộ hình dạng chữ số, v.v. 

Hàm kích hoạt

Thông thường, một hàm kích hoạt ReLU được áp dụng sau mỗi phép tích chập. Hàm này giúp mạng học được các quan hệ phi tuyến giữa các đặc trưng trong hình ảnh, làm mạng mạnh hơn trong việc nhận diện các mẫu khác nhau. Nó cũng giúp giảm thiểu vấn đề tiêu biến gradient. 

Lớp pooling

Mục tiêu của lớp pooling là rút ra các đặc trưng quan trọng nhất từ ma trận đã tích chập. Điều này được thực hiện bằng cách áp dụng một số phép kết gộp, giúp giảm kích thước của bản đồ đặc trưng (ma trận đã tích chập), từ đó giảm bộ nhớ sử dụng khi huấn luyện mạng.  Pooling cũng hữu ích để giảm hiện tượng quá khớp.

Các lớp kết nối đầy đủ

Các lớp này nằm ở phần cuối của mạng nơ-ron tích chập, và đầu vào của chúng tương ứng với ma trận một chiều đã làm phẳng được tạo ra bởi lớp pooling cuối cùng. Các hàm kích hoạt ReLU được áp dụng để đảm bảo tính phi tuyến. 

Convolution Neural Network Architecture.Kiến trúc Mạng Nơ-ron Tích chập. Nguồn: DataCamp

Bạn có thể đọc giải thích chi tiết hơn về toán học đằng sau CNN trong hướng dẫn của chúng tôi, Convolutional Neural Networks in Python.

Vì sao dùng CNN cho phân loại hình ảnh?

Mạng nơ-ron tích chập là một trong những đổi mới có ảnh hưởng nhất trong lĩnh vực thị giác máy tính. Chúng hoạt động tốt hơn nhiều so với các mô hình học máy truyền thống, như SVMcây quyết định, và đã đạt được các kết quả tiên tiến. 

Hơn nữa, các lớp tích chập mang lại cho CNN đặc tính bất biến theo tịnh tiến, giúp chúng có khả năng nhận diện và trích xuất mẫu, đặc trưng từ dữ liệu bất kể sự thay đổi về vị trí, hướng, tỉ lệ hay tịnh tiến.


CNN đã chứng tỏ thành công trong nhiều nghiên cứu tình huống và ứng dụng đời thực, như:

  • Phân loại hình ảnh, phát hiện đối tượng, phân đoạn, nhận dạng khuôn mặt;
  • Xe tự lái sử dụng hệ thống thị giác dựa trên CNN;
  • Phân loại cấu trúc tinh thể bằng mạng nơ-ron tích chập;
  • Hệ thống camera an ninh.

Vượt ra ngoài nhiệm vụ phân loại hình ảnh, CNN rất linh hoạt và có thể áp dụng cho nhiều lĩnh vực khác, như xử lý ngôn ngữ tự nhiên, phân tích chuỗi thời gian và nhận dạng giọng nói.

Triển khai một CNN với PyTorch

Giờ bạn đã quen với lý thuyết về CNN, chúng ta sẵn sàng bắt tay vào thực hành. Ở phần này, chúng ta sẽ xây dựng và huấn luyện một CNN đơn giản bằng PyTorch. Mục tiêu là xây dựng mô hình phân loại chữ số trong ảnh. Để huấn luyện và kiểm tra mô hình, chúng ta sẽ sử dụng bộ dữ liệu MNIST nổi tiếng, gồm 70.000 ảnh xám 28x28 của chữ số viết tay.

1. Import các thư viện cần thiết

Dưới đây là các thư viện chúng ta sẽ dùng cho hướng dẫn này. Về cốt lõi, chúng ta sẽ tận dụng PyTorch để xây dựng CNN, và mô-đun thị giác máy tính của PyTorch là torchvision để tải xuống và nạp bộ dữ liệu MNIST. Cuối cùng, chúng ta cũng sẽ dùng torchmetrics để đánh giá hiệu năng mô hình.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


import torch
from torch import optim
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

# !pip install torchvision
import torchvision

import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# !pip install torchmetrics
import torchmetrics

2. Tải và tiền xử lý dữ liệu

PyTorch còn đi kèm một hệ sinh thái phong phú các công cụ và phần mở rộng, bao gồm torchvision, một mô-đun cho thị giác máy tính. Torchvision có sẵn nhiều bộ dữ liệu hình ảnh có thể dùng để huấn luyện và kiểm thử mạng nơ-ron. Trong hướng dẫn này, chúng ta sẽ dùng bộ dữ liệu MNIST. 

Đầu tiên, chúng ta sẽ tải xuống và chuyển đổi bộ dữ liệu MNIST thành tensor, cấu trúc dữ liệu cốt lõi trong PyTorch, tương tự mảng NumPy nhưng có khả năng tăng tốc bằng GPU.

Sau đó, chúng ta cũng sẽ dùng DataLoader để xử lý việc chia lô và xáo trộn cho cả tập train và test. Một DataLoader của PyTorch có thể được tạo từ một Dataset để tải dữ liệu, chia thành các lô và thực hiện các biến đổi trên dữ liệu nếu muốn. Sau đó, nó cung cấp một mẫu dữ liệu sẵn sàng cho huấn luyện. Trong đoạn mã dưới đây, chúng ta tải dữ liệu và lưu vào các DataLoader với kích thước lô 60 ảnh:

batch_size = 60

train_dataset = datasets.MNIST(root="dataset/", download=True, train=True, transform=transforms.ToTensor())

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = datasets.MNIST(root="dataset/", download=True, train=False, transform=transforms.ToTensor())

test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

Tùy chọn, tập train có thể được chia tiếp thành hai phần train và validation. Validation là kỹ thuật dùng trong học sâu để đánh giá hiệu năng mô hình trong quá trình huấn luyện. Nó giúp phát hiện khả năng quá khớp và thiếu khớp của mô hình, và đặc biệt hữu ích để tối ưu siêu tham số. Tuy nhiên, để đơn giản, chúng ta sẽ không dùng validation trong hướng dẫn này. Nếu muốn tìm hiểu thêm về validation, bạn có thể xem giải thích đầy đủ trong khóa học Introduction to Deep Learning with PyTorch.

Giờ chúng ta đã có dữ liệu, hãy xem một lô ngẫu nhiên các chữ số trông như thế nào:

def imshow(img):
   npimg = img.numpy()
   plt.imshow(np.transpose(npimg, (1, 2, 0)))
   plt.show()

# get some random training images
dataiter = iter(dataloader_train)
images, labels = next(dataiter)
labels
# show images
imshow(torchvision.utils.make_grid(images))

3. Định nghĩa kiến trúc CNN

Để giải bài toán phân loại, chúng ta sẽ tận dụng lớp nn.Module, khối xây dựng của PyTorch để tạo các kiến trúc mạng nơ-ron tinh vi một cách trực quan. 

Trong đoạn mã dưới đây, chúng ta tạo một lớp tên là CNN, kế thừa các thuộc tính của lớp nn.Module. Lớp CNN sẽ là bản thiết kế cho một CNN với hai lớp tích chập, tiếp theo là một lớp kết nối đầy đủ. 

Trong PyTorch, chúng ta dùng nn.Conv2d để định nghĩa lớp tích chập. Ta truyền vào số lượng bản đồ đặc trưng đầu vào và đầu ra. Ta cũng thiết lập một số tham số cho lớp tích chập hoạt động, bao gồm kích thước kernel hay bộ lọc và padding. 

Tiếp theo, chúng ta thêm một lớp gộp cực đại với nn.MaxPool2d. Trong đó, chúng ta trượt một cửa sổ không chồng lấn trên đầu ra của lớp tích chập trước đó. Ở mỗi vị trí, chúng ta chọn giá trị lớn nhất trong cửa sổ để đưa tiếp về sau. Phép toán này làm giảm kích thước không gian của các bản đồ đặc trưng, giảm số lượng tham số và độ phức tạp tính toán của mạng. Cuối cùng, chúng ta thêm một lớp tuyến tính kết nối đầy đủ. 

Hàm forward() định nghĩa cách các lớp khác nhau được kết nối, thêm một số hàm kích hoạt ReLU sau mỗi lớp tích chập.

class CNN(nn.Module):
   def __init__(self, in_channels, num_classes):

       """
       Building blocks of convolutional neural network.

       Parameters:
           * in_channels: Number of channels in the input image (for grayscale images, 1)
           * num_classes: Number of classes to predict. In our problem, 10 (i.e digits from  0 to 9).
       """
       super(CNN, self).__init__()

       # 1st convolutional layer
       self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=3, padding=1)
       # Max pooling layer
       self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
       # 2nd convolutional layer
       self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
       # Fully connected layer
       self.fc1 = nn.Linear(16 * 7 * 7, num_classes)

   def forward(self, x):
       """
       Define the forward pass of the neural network.

       Parameters:
           x: Input tensor.

       Returns:
           torch.Tensor
               The output tensor after passing through the network.
       """
       x = F.relu(self.conv1(x))  # Apply first convolution and ReLU activation
       x = self.pool(x)           # Apply max pooling
       x = F.relu(self.conv2(x))  # Apply second convolution and ReLU activation
       x = self.pool(x)           # Apply max pooling
       x = x.reshape(x.shape[0], -1)  # Flatten the tensor
       x = self.fc1(x)            # Apply fully connected layer
       return x
       x = x.reshape(x.shape[0], -1)  # Flatten the tensor
       x = self.fc1(x)            # Apply fully connected layer
       return x

Sau khi định nghĩa lớp CNN, chúng ta có thể tạo mô hình và chuyển nó tới thiết bị nơi nó sẽ được huấn luyện và chạy. 

Các mạng nơ-ron, bao gồm CNN, thường cho hiệu năng tốt hơn khi chạy trên GPU, nhưng điều đó có thể không đúng với máy của bạn. Do đó, chúng ta sẽ chỉ chạy mô hình trên GPU khi có sẵn; nếu không, sẽ dùng CPU thông thường.

device = "cuda" if torch.cuda.is_available() else "cpu"

model = CNN(in_channels=1, num_classes=10).to(device)
print(model)
>>> CNN(
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=784, out_features=10, bias=True)
)

4. Huấn luyện mô hình CNN

Giờ khi đã có mô hình, đã đến lúc huấn luyện. Để làm vậy, trước tiên chúng ta cần xác định cách đo lường hiệu năng mô hình. Vì chúng ta đang xử lý bài toán phân loại đa lớp, chúng ta sẽ dùng hàm mất mát cross-entropy, có sẵn trong PyTorch dưới dạng nn.CrossEntropyLoss. Chúng ta cũng sẽ sử dụng bộ tối ưu Adam, một trong những thuật toán tối ưu phổ biến nhất. 

# Define the loss function
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

Chúng ta sẽ lặp qua mười epoch và các lô huấn luyện để huấn luyện mô hình và thực hiện chuỗi bước thông thường cho mỗi lô như dưới đây.

num_epochs=10
for epoch in range(num_epochs):
 # Iterate over training batches
   print(f"Epoch [{epoch + 1}/{num_epochs}]")

   for batch_index, (data, targets) in enumerate(tqdm(dataloader_train)):
       data = data.to(device)
       targets = targets.to(device)
       scores = model(data)
       loss = criterion(scores, targets)
       optimizer.zero_grad()
       loss.backward()
       optimizer.step()
Epoch [1/10]
100%|██████████| 1000/1000 [00:13<00:00, 72.94it/s]
Epoch [2/10]
100%|██████████| 1000/1000 [00:12<00:00, 77.27it/s]
Epoch [3/10]
100%|██████████| 1000/1000 [00:12<00:00, 77.16it/s]
Epoch [4/10]
100%|██████████| 1000/1000 [00:12<00:00, 77.00it/s]
Epoch [5/10]
100%|██████████| 1000/1000 [00:13<00:00, 75.69it/s]
Epoch [6/10]
100%|██████████| 1000/1000 [00:12<00:00, 77.24it/s]
Epoch [7/10]
100%|██████████| 1000/1000 [00:12<00:00, 78.23it/s]
Epoch [8/10]
100%|██████████| 1000/1000 [00:12<00:00, 78.16it/s]
Epoch [9/10]
100%|██████████| 1000/1000 [00:12<00:00, 77.96it/s]
Epoch [10/10]
100%|██████████| 1000/1000 [00:12<00:00, 77.93it/s]

5. Đánh giá mô hình

Khi mô hình đã được huấn luyện, chúng ta có thể đánh giá hiệu năng của nó trên tập dữ liệu kiểm thử. Chúng ta sẽ dùng độ chính xác (accuracy), một chỉ số phổ biến cho các bài toán phân loại. Độ chính xác đo tỷ lệ các trường hợp được phân loại đúng trên tổng số đối tượng trong bộ dữ liệu. Nó được tính bằng cách chia số dự đoán đúng cho tổng số dự đoán mà mô hình tạo ra. 

Đầu tiên, chúng ta thiết lập chỉ số accuracy từ torchmetrics. Tiếp theo, chúng ta dùng phương thức .eval của mô hình để đưa mô hình vào chế độ đánh giá, vì một số lớp trong mô hình PyTorch hoạt động khác nhau ở giai đoạn huấn luyện so với kiểm thử. Chúng ta cũng thêm ngữ cảnh Python với torch.no_grad, cho biết sẽ không thực hiện tính toán gradient.

Sau đó, chúng ta lặp qua các ví dụ kiểm thử mà không tính gradient. Với mỗi lô kiểm thử, chúng ta lấy đầu ra mô hình, chọn lớp có khả năng cao nhất và truyền nó vào hàm accuracy cùng với nhãn. Cuối cùng, chúng ta tính toán chỉ số và in kết quả. Chúng ta đạt điểm accuracy 0,98, nghĩa là mô hình phân loại đúng 98% chữ số. Không tệ!

# Set up of multiclass accuracy metric
acc = Accuracy(task="multiclass",num_classes=10)

# Iterate over the dataset batches
model.eval()
with torch.no_grad():
   for images, labels in dataloader_test:
       # Get predicted probabilities for test data batch
       outputs = model(images)
       _, preds = torch.max(outputs, 1)
       acc(preds, labels)
       precision(preds, labels)
       recall(preds, labels)

#Compute total test accuracy
test_accuracy = acc.compute()
print(f"Test accuracy: {test_accuracy}")

>>> Test accuracy: 0.9857000112533569

Bạn cũng có thể dùng các chỉ số phân loại phổ biến khác, gồm recall và precision. Chúng tôi trình bày đầy đủ về các chỉ số này với ví dụ thực tiễn trong khóa học Intermediate Deep Learning with PyTorch.

Cải thiện hiệu năng mô hình

Mặc dù mô hình CNN của chúng ta đạt hiệu năng mạnh mẽ, vẫn có một số chiến lược có thể dùng để tiếp tục nâng cao độ chính xác, độ vững và khả năng khái quát hóa cho dữ liệu mới. 

Trong phần này, chúng ta sẽ khám phá các kỹ thuật chủ chốt như tăng cường dữ liệu, tinh chỉnh siêu tham số và học chuyển giao để tối ưu hiệu năng của mô hình.

Kỹ thuật tăng cường dữ liệu

Tăng cường dữ liệu là kỹ thuật dùng để cải thiện độ chính xác của mô hình bằng cách ngẫu nhiên tạo thêm dữ liệu huấn luyện mới. Ví dụ, trong quá trình tải, có thể áp dụng các biến đổi lên ảnh huấn luyện, như thay đổi kích thước, lật ngang hoặc dọc, xoay ngẫu nhiên, v.v. Bằng cách đó, có thể tạo ra các ảnh tăng cường và gán cho chúng nhãn giống ảnh gốc, từ đó tăng kích thước tập huấn luyện.

Việc thêm các biến đổi ngẫu nhiên vào ảnh gốc cho phép chúng ta tạo ra nhiều dữ liệu hơn đồng thời tăng kích thước và độ đa dạng của tập huấn luyện. Nó giúp mô hình vững vàng hơn trước các biến thiên và méo dạng thường gặp trong ảnh thực tế, và giảm quá khớp khi mô hình học cách bỏ qua các biến đổi ngẫu nhiên. 

Tuy nhiên, cần thận trọng với tăng cường dữ liệu, vì đôi khi nó có thể gây hại cho quá trình huấn luyện. Chẳng hạn, trong bài toán của chúng ta, nếu áp dụng lật dọc với số "6", nó sẽ trông giống số "9". Truyền nó vào mô hình với nhãn "6" sẽ gây nhiễu mô hình và cản trở việc huấn luyện. Những ví dụ này cho thấy đôi khi các phép tăng cường cụ thể có thể ảnh hưởng đến nhãn.

Tinh chỉnh siêu tham số

Một chiến lược khác để cải thiện hiệu năng mô hình là thay đổi giá trị của các siêu tham số liên quan đến các lớp khác nhau của mô hình. Việc tinh chỉnh siêu tham số này đòi hỏi hiểu biết sâu về toán học đằng sau mạng nơ-ron và ý nghĩa của các siêu tham số khác nhau. 

Ví dụ, bạn có thể tinh chỉnh các lớp CNN bằng cách thay đổi kích thước bộ lọc hoặc tăng padding. Bạn cũng có thể đặt giá trị khác cho trọng số khởi tạo của các nơ-ron. 

Vì chúng ta sẽ không biết trước các giá trị tối ưu của siêu tham số, sẽ cần một mức độ thử nghiệm. Điều này thường được thực hiện qua kỹ thuật gọi là grid search, cho phép bạn đánh giá có hệ thống một mô hình trên một lưới các giá trị tham số. 

Tuy nhiên, hãy lưu ý khi dùng kỹ thuật này, vì nó thường tốn nhiều tài nguyên tính toán, đặc biệt khi làm việc với mạng nơ-ron phức tạp và bộ dữ liệu huấn luyện lớn.

Tương tự, bạn có thể tăng độ phức tạp của mô hình bằng cách thêm nhiều lớp tích chập và tuyến tính. Tuy vậy, hãy cẩn trọng khi thêm lớp mới, vì số lượng nơ-ron có thể tăng mạnh, dẫn đến thời gian huấn luyện dài hơn và nguy cơ quá khớp.

Bạn có thể tìm hiểu thêm về tinh chỉnh siêu tham số trong khóa học Introduction to Deep Learning with PyTorch.

Sử dụng mô hình tiền huấn luyện

Huấn luyện các mô hình học sâu từ đầu là một quá trình dài và tẻ nhạt, và thường đòi hỏi rất nhiều dữ liệu huấn luyện. Thay vào đó, chúng ta thường có thể sử dụng các mô hình tiền huấn luyện, tức là các mô hình đã được huấn luyện sẵn cho một số tác vụ nào đó. 

Đôi khi, chúng ta có thể tái sử dụng trực tiếp một mô hình tiền huấn luyện nếu nó đã có thể giải quyết tác vụ mà chúng ta quan tâm. Ở những trường hợp khác, chúng ta có thể cần điều chỉnh mô hình tiền huấn luyện để phù hợp với tác vụ mới. Điều này được gọi là học chuyển giao.

Sử dụng các mô hình tiền huấn luyện trong PyTorch khá dễ dàng. Torchvision cung cấp một tập hợp các mô hình tiền huấn luyện cho nhiều tác vụ liên quan đến hình ảnh. Những mô hình này được huấn luyện trên các bộ dữ liệu hình ảnh quy mô lớn và có sẵn để sử dụng. Hãy xem khóa học Deep Learning for Images with PyTorch để học mọi điều bạn cần biết về chúng.

Triển khai mô hình CNN

Sau khi huấn luyện mô hình phân loại có độ chính xác cao bằng PyTorch, bạn có thể lưu mô hình và trọng số đã huấn luyện để sử dụng sau và chia sẻ với nhóm của mình, đảm bảo họ có thể nạp lại một cách liền mạch.

Để lưu mô hình, chúng ta có thể dùng torch.save. Phần mở rộng tệp phổ biến cho các mô hình torch là pt hoặc pth. Để lưu trọng số của mô hình, chúng ta truyền model.state_dict vào torch.save kèm tên tệp đầu ra, ví dụ MulticlassCNN.pth.

Để nạp một mô hình đã lưu, chúng ta khởi tạo một mô hình mới với cùng kiến trúc. Sau đó dùng phương thức load state dict cùng với torch.load để nạp các tham số vào mô hình mới.

# Save the model
torch.save(model.state_dict(), 'MulticlassCNN.pth')

# Create a new model
loaded_model = CNN(in_channels=1, num_classes=10)

# Load the saved model
loaded_model.load_state_dict(torch.load('MulticlassCNN.pth'))
print(loaded_model)


CNN(
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=784, out_features=10, bias=True)
)

Kết luận

Chúng ta đã bao quát tổng quan đầy đủ về CNN, cung cấp chi tiết về từng lớp trong kiến trúc CNN. Tiếp đó, chúng ta đưa ra hướng dẫn cách triển khai một CNN trong PyTorch, bao trùm các bước chính, từ tải dữ liệu và thiết kế mô hình đến huấn luyện và đánh giá mô hình. Cuối cùng, chúng ta cũng phân tích một số chiến lược để cải thiện hiệu năng mô hình. Chúng ta đã áp dụng toàn bộ kỹ năng này cho một kịch bản thực tế liên quan đến tác vụ phân loại đa lớp. 

Có rất nhiều điều để học về học sâu, có lẽ là một trong những lĩnh vực thú vị và đầy thách thức nhất của AI. May mắn thay, DataCamp ở đây để hỗ trợ. Hãy xem các tài liệu và khóa học chuyên sâu của chúng tôi và trở thành chuyên gia về mạng nơ-ron:


Javier Canales Luna's photo
Author
Javier Canales Luna
LinkedIn

Tôi là một chuyên viên phân tích dữ liệu tự do, hợp tác với các công ty và tổ chức trên toàn thế giới trong các dự án khoa học dữ liệu. Tôi cũng là giảng viên khoa học dữ liệu với hơn 2 năm kinh nghiệm. Tôi thường xuyên viết bài về khoa học dữ liệu bằng tiếng Anh và tiếng Tây Ban Nha; một số bài đã được đăng trên các trang uy tín như DataCamp, Towards Data Science và Analytics Vidhya. Là một nhà khoa học dữ liệu có nền tảng khoa học chính trị và luật, mục tiêu của tôi là làm việc tại giao điểm giữa chính sách công, pháp luật và công nghệ, tận dụng sức mạnh của ý tưởng để thúc đẩy các giải pháp và cách tiếp cận mới nhằm giúp chúng ta đối mặt với những thách thức cấp bách, đặc biệt là khủng hoảng khí hậu. Tôi xem mình là người tự học, không ngừng trau dồi và là một người ủng hộ vững chắc cho tính đa ngành. Không bao giờ là quá muộn để học điều mới.

Chủ đề

Những khóa học hàng đầu trên DataCamp

Courses

Nhập môn Deep Learning với Python

4 giờ
263.4K
Tìm hiểu các kiến thức cơ bản về mạng nơ-ron và cách xây dựng mô hình học sâu bằng Keras 2.0 trong Python.
Xem chi tiếtRight Arrow
Bắt đầu khóa học
Xem thêmRight Arrow