Fine tuning là gì

1. Introduction

1.1 Fine-tuning là gì ?

Chắc hẳn hồ hết ai thao tác làm việc với các model trong deep learning các đã nghe/quen với có mang Transfer learning với Fine tuning. Quan niệm tổng quát: Transfer learning là tận dụng học thức học được từ là một vấn đề để áp dụng vào 1 vấn đề có liên quan khác. Một ví dụ solo giản: thay vị train 1 model mới hoàn toàn cho vấn đề phân một số loại chó/mèo, người ta rất có thể tận dụng 1 mã sản phẩm đã được train bên trên ImageNet dataset cùng với hằng triệu ảnh. Pre-trained model này sẽ tiến hành train tiếp bên trên tập dataset chó/mèo, quy trình train này ra mắt nhanh hơn, hiệu quả thường xuất sắc hơn. Có khá nhiều kiểu Transfer learning, các chúng ta có thể tham khảo trong bài xích này: Tổng thích hợp Transfer learning. Trong bài bác này, mình sẽ viết về 1 dạng transfer learning phổ biến: Fine-tuning.

Bạn đang xem: Fine tuning là gì

Hiểu đối kháng giản, fine-tuning là chúng ta lấy 1 pre-trained model, tận dụng một phần hoặc cục bộ các layer, thêm/sửa/xoá 1 vài ba layer/nhánh để tạo ra 1 mã sản phẩm mới. Thường những layer đầu của model được freeze (đóng băng) lại - tức weight những layer này sẽ không còn bị đổi khác giá trị trong quy trình train. Lý do bởi những layer này đã có khả năng trích xuất thông tin mức trìu tượng thấp , năng lực này được học tập từ quá trình training trước đó. Ta freeze lại nhằm tận dụng được khả năng này cùng giúp việc train diễn ra nhanh rộng (model chỉ phải update weight ở những layer cao). Có rất nhiều các Object detect mã sản phẩm được thiết kế dựa trên những Classifier model. VD Retina model (Object detect) được xuất bản với backbone là Resnet.

*

1.2 vì sao pytorch thay bởi vì Keras ?

Chủ đề bài viết hôm nay, mình sẽ hướng dẫn fine-tuning Resnet50 - 1 pre-trained model được hỗ trợ sẵn trong torchvision của pytorch. Vì sao là pytorch mà không phải Keras ? vì sao bởi bài toán fine-tuning mã sản phẩm trong keras rất đối kháng giản. Dưới đó là 1 đoạn code minh hoạ cho việc xây dựng 1 Unet dựa vào Resnet trong Keras:

from tensorflow.keras import applicationsresnet = applications.resnet50.ResNet50()layer_3 = resnet.get_layer("activation_9").outputlayer_7 = resnet.get_layer("activation_21").outputlayer_13 = resnet.get_layer("activation_39").outputlayer_16 = resnet.get_layer("activation_48").output#Adding outputs decoder with encoder layersfcn1 = Conv2D(...)(layer_16)fcn2 = Conv2DTranspose(...)(fcn1)fcn2_skip_connected = Add()()fcn3 = Conv2DTranspose(...)(fcn2_skip_connected)fcn3_skip_connected = Add()()fcn4 = Conv2DTranspose(...)(fcn3_skip_connected)fcn4_skip_connected = Add()()fcn5 = Conv2DTranspose(...)(fcn4_skip_connected)Unet = Model(inputs = resnet.input, outputs=fcn5)Bạn rất có thể thấy, fine-tuning mã sản phẩm trong Keras thực thụ rất đơn giản, dễ dàng làm, dễ dàng hiểu. Việc showroom thêm các nhánh rất dễ dàng bởi cú pháp solo giản. Vào pytorch thì ngược lại, thành lập 1 model Unet tương tự sẽ rất vất vả cùng phức tạp. Fan mới học sẽ gặp mặt khó khăn vì chưng trên mạng ko nhiều những hướng dẫn cho câu hỏi này. Vậy nên bài xích này mình sẽ hướng dẫn chi tiết cách fine-tune vào pytorch để vận dụng vào bài toán Visual Saliency prediction

2. Visual Saliency prediction

2.1 What is Visual Saliency ?

*

Khi nhìn vào 1 bức ảnh, mắt hay có xu hướng tập trung chú ý vào 1 vài đơn vị chính. Ảnh trên đấy là 1 minh hoạ, màu đá quý được sử dụng để biểu lộ mức độ thu hút. Saliency prediction là việc mô phỏng sự triệu tập của mắt fan khi quan cạnh bên 1 bức ảnh. Gắng thể, bài bác toán yên cầu xây dựng 1 model, mã sản phẩm này nhận ảnh đầu vào, trả về 1 mask mô phỏng mức độ thu hút. Như vậy, model nhận vào 1 đầu vào image với trả về 1 mask có kích thước tương đương.

Để rõ rộng về câu hỏi này, chúng ta có thể đọc bài: Visual Saliency Prediction with Contextual Encoder-Decoder Network.Dataset phổ biến nhất: SALICON DATASET

2.2 Unet

Note: Bạn rất có thể bỏ qua phần này nếu đang biết về Unet

Đây là 1 trong những bài toán Image-to-Image. Để giải quyết và xử lý bài toán này, mình sẽ xây dựng dựng 1 model theo kiến trúc Unet. Unet là 1 kiến trúc được sử dụng nhiều trong câu hỏi Image-to-image như: semantic segmentation, tự động hóa color, super resolution ... Phong cách thiết kế của Unet có điểm tựa như với phong cách xây dựng Encoder-Decoder đối xứng, được thêm những skip connection tự Encode quý phái Decode tương ứng. Về cơ bản, những layer càng tốt càng trích xuất thông tin ở nấc trìu tượng cao, điều ấy đồng nghĩa cùng với việc những thông tin nút trìu tượng rẻ như mặt đường nét, màu sắc, độ phân giải... Sẽ ảnh hưởng mất mát đi trong quá trình lan truyền. Tín đồ ta thêm các skip-connection vào để giải quyết vấn đề này.

Với phần Encode, feature-map được downscale bằng các Convolution. Ngược lại, ở đoạn decode, feature-map được upscale bởi những Upsampling layer, trong bài xích này mình sử dụng những Convolution Transpose.

*

2.3 Resnet

Để xử lý bài toán, mình sẽ xây dựng model Unet cùng với backbone là Resnet50. Các bạn nên tò mò về Resnet nếu không biết về phong cách thiết kế này. Hãy quan giáp hình minh hoạ dưới đây. Resnet50 được phân thành các khối lớn . Unet được xây dừng với Encoder là Resnet50. Ta sẽ lấy ra output của từng khối, tạo những skip-connection kết nối từ Encoder sang trọng Decoder. Decoder được thiết kế bởi các Convolution Transpose layer (xen kẽ trong đó là những lớp Convolution nhằm mục đích sút số chanel của feature bản đồ -> giảm con số weight đến model).

Theo quan điểm cá nhân, pytorch rất giản đơn code, dễ nắm bắt hơn tương đối nhiều so với Tensorflow 1.x hoặc ngang ngửa Keras. Tuy nhiên, vấn đề fine-tuning model trong pytorch lại cạnh tranh hơn không hề ít so với Keras. Trong Keras, ta không buộc phải quá thân thiết tới loài kiến trúc, luồng cách xử lý của model, chỉ việc lấy ra các output tại 1 số ít layer nhất thiết làm skip-connection, ghép nối và tạo ra ra mã sản phẩm mới.

*

Trong pytorch thì ngược lại, bạn cần hiểu được luồng xử trí và copy code các layer mong muốn giữ lại trong model mới. Hình bên trên là code của resnet50 trong torchvision. Chúng ta có thể tham khảo link: torchvision-resnet50. Bởi vậy khi gây ra Unet như phong cách xây dựng đã mô tả mặt trên, ta cần đảm bảo đoạn code từ bỏ Conv1 -> Layer4 không biến thành thay đổi. Hãy tham khảo phần tiếp theo để làm rõ hơn.

3. Code

Tất cả code của bản thân mình được gói gọn trong file notebook Salicon_main.ipynb. Bạn cũng có thể tải về và run code theo link github: github/trungthanhnguyen0502 . Trong bài viết mình sẽ chỉ gửi ra đông đảo đoạn code chính.

Import những package

import albumentations as Aimport numpy as npimport torchimport torchvisionimport torch.nn as nn import torchvision.transforms as Timport torchvision.models as modelsfrom torch.utils.data import DataLoader, Datasetimport ....

3.1 utils functions

Trong pytorch, dữ liệu có thiết bị tự dimension không giống với Keras/TF/numpy. Thông thường với numpy giỏi keras, ảnh có dimension theo sản phẩm tự (batchsize,h,w,chanel)(batchsize, h, w, chanel)(batchsize,h,w,chanel). Thứ tự vào Pytorch trái lại là (batchsize,chanel,h,w)(batchsize, chanel, h, w)(batchsize,chanel,h,w). Mình sẽ xây dựng 2 hàm toTensor với toNumpy để thay đổi qua lại thân hai format này.

def toTensor(np_array, axis=(2,0,1)): return torch.tensor(np_array).permute(axis)def toNumpy(tensor, axis=(1,2,0)): return tensor.detach().cpu().permute(axis).numpy() ## display one image in notebookdef plot_img(img): ... ## display multi imagedef plot_imgs(imgs): ...

3.2 Define model

3.2.1 Conv & Deconv

Mình sẽ xây dựng 2 function trả về module Convolution cùng Convolution Transpose (Deconv)

def Deconv(n_input, n_output, k_size=4, stride=2, padding=1): Tconv = nn.ConvTranspose2d( n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False) block = < Tconv, nn.BatchNorm2d(n_output), nn.LeakyReLU(inplace=True), > return nn.Sequential(*block) def Conv(n_input, n_output, k_size=4, stride=2, padding=0, bn=False, dropout=0): conv = nn.Conv2d( n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False) block = < conv, nn.BatchNorm2d(n_output), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(dropout) > return nn.Sequential(*block)

3.2.2 Unet model

Init function: ta đã copy các layer nên giữ tự resnet50 vào unet. Sau đó khởi tạo những Conv / Deconv layer và các layer cần thiết.

Forward function: cần bảo vệ luồng cách xử trí của resnet50 được giữ nguyên giống code nơi bắt đầu (trừ Fully-connected layer). Sau đó ta ghép nối các layer lại theo bản vẽ xây dựng Unet đã thể hiện trong phần 2.

Tạo model: phải load resnet50 với truyền vào Unet. Đừng quên Freeze những layer của resnet50 vào Unet.

Xem thêm: Hướng Dẫn Cách Làm Bóng Đá Giấy, Cách Để Gấp Bóng Origami

class Unet(nn.Module): def __init__(self, resnet): super().__init__() self.conv1 = resnet.conv1 self.bn1 = resnet.bn1 self.relu = resnet.relu self.maxpool = resnet.maxpool self.tanh = nn.Tanh() self.sigmoid = nn.Sigmoid() # get some layer from resnet to lớn make skip connection self.layer1 = resnet.layer1 self.layer2 = resnet.layer2 self.layer3 = resnet.layer3 self.layer4 = resnet.layer4 # convolution layer, use lớn reduce the number of channel => reduce weight number self.conv_5 = Conv(2048, 512, 1, 1, 0) self.conv_4 = Conv(1536, 512, 1, 1, 0) self.conv_3 = Conv(768, 256, 1, 1, 0) self.conv_2 = Conv(384, 128, 1, 1, 0) self.conv_1 = Conv(128, 64, 1, 1, 0) self.conv_0 = Conv(32, 1, 3, 1, 1) # deconvolution layer self.deconv4 = Deconv(512, 512, 4, 2, 1) self.deconv3 = Deconv(512, 256, 4, 2, 1) self.deconv2 = Deconv(256, 128, 4, 2, 1) self.deconv1 = Deconv(128, 64, 4, 2, 1) self.deconv0 = Deconv(64, 32, 4, 2, 1) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) skip_1 = x x = self.maxpool(x) x = self.layer1(x) skip_2 = x x = self.layer2(x) skip_3 = x x = self.layer3(x) skip_4 = x x5 = self.layer4(x) x5 = self.conv_5(x5) x4 = self.deconv4(x5) x4 = torch.cat(, dim=1) x4 = self.conv_4(x4) x3 = self.deconv3(x4) x3 = torch.cat(, dim=1) x3 = self.conv_3(x3) x2 = self.deconv2(x3) x2 = torch.cat(, dim=1) x2 = self.conv_2(x2) x1 = self.deconv1(x2) x1 = torch.cat(, dim=1) x1 = self.conv_1(x1) x0 = self.deconv0(x1) x0 = self.conv_0(x0) x0 = self.sigmoid(x0) return x0 device = torch.device("cuda")resnet50 = models.resnet50(pretrained=True)model = Unet(resnet50)model.to(device)## Freeze resnet50"s layers in Unetfor i, child in enumerate(model.children()): if i 7: for param in child.parameters(): param.requires_grad = False

3.3 Dataset and Dataloader

Dataset trả dấn 1 list những image_path cùng mask_dir, trả về image và mask tương ứng.

Define MaskDataset

class MaskDataset(Dataset): def __init__(self, img_fns, mask_dir, transforms=None): self.img_fns = img_fns self.transforms = transforms self.mask_dir = mask_dir def __getitem__(self, idx): img_path = self.img_fns img_name = img_path.split("/")<-1>.split(".")<0> mask_fn = f"self.mask_dir/img_name.png" img = cv2.imread(img_path) mask = cv2.imread(mask_fn) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) if self.transforms: sample = "image": img, "mask": mask sample = self.transforms(**sample) img = sample<"image"> mask = sample<"mask"> # to Tensor img = img/255.0 mask = np.expand_dims(mask, axis=-1)/255.0 mask = toTensor(mask).float() img = toTensor(img).float() return img, mask def __len__(self): return len(self.img_fns)Test dataset

img_fns = glob("./Salicon_dataset/image/train/*.jpg")mask_dir = "./Salicon_dataset/mask/train"train_transform = A.Compose(< A.Resize(width=256,height=256, p=1), A.RandomSizedCrop(<240,256>, height=256, width=256, p=0.4), A.HorizontalFlip(p=0.5), A.Rotate(limit=(-10,10), p=0.6),>)train_dataset = MaskDataset(img_fns, mask_dir, train_transform)train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)# thử nghiệm datasetimg, mask = next(iter(train_dataset))img = toNumpy(img)mask = toNumpy(mask)<:,:,0>img = (img*255.0).astype(np.uint8)mask = (mask*255.0).astype(np.uint8)heatmap_img = cv2.applyColorMap(mask, cv2.COLORMAP_JET)combine_img = cv2.addWeighted(img, 0.7, heatmap_img, 0.3, 0)plot_imgs(

3.4 Train model

Vì bài bác toán đơn giản và làm cho dễ hiểu, mình đang train theo cách đơn giản dễ dàng nhất, không validate trong qúa trình train nhưng chỉ lưu model sau 1 số epoch độc nhất định

train_params = optimizer = torch.optim.Adam(train_params, lr=0.001, betas=(0.9, 0.99))epochs = 5model.train()saved_dir = "model"os.makedirs(saved_dir, exist_ok=True)loss_function = nn.MSELoss(reduce="mean")for epoch in range(epochs): for imgs, masks in tqdm(train_loader): imgs_gpu = imgs.to(device) outputs = model(imgs_gpu) masks = masks.to(device) loss = loss_function(outputs, masks) loss.backward() optimizer.step()

3.5 demo model

img_fns = glob("./Salicon_dataset/image/val/*.jpg")mask_dir = "./Salicon_dataset/mask/val"val_transform = A.Compose(< A.Resize(width=256,height=256, p=1), A.HorizontalFlip(p=0.5),>)model.eval()val_dataset = MaskDataset(img_fns, mask_dir, val_transform)val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, drop_last=True)imgs, mask_targets = next(iter(val_loader))imgs_gpu = imgs.to(device)mask_outputs = model(imgs_gpu)mask_outputs = toNumpy(mask_outputs, axis=(0,2,3,1))imgs = toNumpy(imgs, axis=(0,2,3,1))mask_targets = toNumpy(mask_targets, axis=(0,2,3,1))for i, img in enumerate(imgs): img = (img*255.0).astype(np.uint8) mask_output = (mask_outputs*255.0).astype(np.uint8) mask_target = (mask_targets*255.0).astype(np.uint8) heatmap_label = cv2.applyColorMap(mask_target, cv2.COLORMAP_JET) heatmap_pred = cv2.applyColorMap(mask_output, cv2.COLORMAP_JET) origin_img = cv2.addWeighted(img, 0.7, heatmap_label, 0.3, 0) predict_img = cv2.addWeighted(img, 0.7, heatmap_pred, 0.3, 0) result = np.concatenate((img,origin_img, predict_img),axis=1) plot_img(result)Kết trái thu được:

*

Đây là bài xích toán dễ dàng nên mình chú ý vào quy trình và phương pháp fine tuning vào pytorch rộng là đi sâu vào giải quyết bài toán. Cảm ơn chúng ta đã đọc

4. Reference

Dataset: salicon.net

Code bài xích viết: https://github.com/trungthanhnguyen0502/-paydayloanssqa.com-Visual-Saliency-prediction

Resnet50 torchvision code: torchvision-resnet

Bài viết cùng chủ đề Visual saliency: Visual Saliency Prediction with Contextual Encoder-Decoder Network!

Theo dõi các nội dung bài viết chuyên sâu về AI/Deep learning tại: Vietnam AI link Sharing Community