[NMDL] Chính quy hoá tham số (Weight regularization)
Tìm hiểu tham số weight_decay trong các Optimizer thường dùng.
Tưởng tượng bạn đang xây một ngôi nhà. Bạn muốn ngôi nhà thật đẹp, thật vững chắc, nhưng lại không muốn tốn quá nhiều chi phí. Nếu bạn xây quá nhiều phòng, quá nhiều tầng, ngôi nhà sẽ rất dễ bị sập khi gặp gió lớn (overfitting). Để tránh điều này, bạn cần đơn giản hóa ngôi nhà, loại bỏ những chi tiết không cần thiết. Weight regularization cũng hoạt động tương tự như vậy, nó giúp chúng ta xây dựng các mô hình máy học đơn giản hơn, nhưng vẫn đảm bảo độ chính xác.
Weight Regularization (Chính quy hoá trọng số) sẽ điều chỉnh mô hình ngay trong quá trình huấn luyện bằng cách loại bỏ những trường hợp có khả năng cao sẽ làm “sập nhà”. Ví dụ, trong một mạng neural có quá một vài trọng số có giá trị quá lớn. Các trọng số này sẽ dễ dàng bị kích hoạt dù là đầu vào có là gì vì bản thân giá trị của chúng đã quá lớn. Nguyên nhân lớn là trong quá trình huấn luyện, sự dễ kích hoạt (sensitive) này không ảnh hưởng vì dữ liệu huấn luyện bị giới hạn. Nếu không được giải quyết, khi đem mô hình ra thử trên dữ liệu thực tế, khả năng (rất) cao là chúng sẽ khiến cho mô hình hoạt động kém hiệu quả. Một lần nữa, đây chính là overfitting problem.
Ta có thể giải thích vấn đề như phía trên là sự bất thường trong phân phối giá trị của các trọng số sau khi được huấn luyện. Weight Regularization tìm cách mô hình hoá toán học các dấu hiệu trong bộ trọng số như vậy để phạt mô hình trong quá trình huấn luyện. Giả sử, với một mô hình có bột trọng số là w, được tối ưu bằng một hàm loss L, và ta chọn được một hàm số để “regularization” là R(w) (hàm này phạt nặng nếu bộ trọng số có dấu hiệu xấu). Và, cách đơn giản nhất là cộng luôn regularization term này vào hàm loss ban đầu để có được hàm loss mới:
L = L + λ R(w)
Với tham số λ (lambda), chúng ta có thể điều chỉnh mức độ ảnh hưởng của hàm regularisation đến quá trình huấn luyện mô hình.
Khi λ tiến về 0: Hàm regularisation gần như không có tác dụng. Mô hình sẽ tập trung chủ yếu vào việc giảm thiểu hàm mất mát (loss function) trên tập dữ liệu huấn luyện, có thể dẫn đến hiện tượng overfitting nếu mô hình quá phức tạp.Khi λ tiến về vô cùng: Hàm regularisation trở nên quá mạnh, mô hình sẽ ưu tiên giảm thiểu giá trị của hàm regularisation hơn là giảm thiểu hàm mất mát. Điều này có thể khiến mô hình trở nên quá đơn giản và không thể bắt được các đặc trưng phức tạp trong dữ liệu.
Trong bài viết này, chúng ta sẽ tập trung vào hai kỹ thuật regularization phổ biến dựa trên norm là L1 và L2.
Một số ký hiệu được sử dụng:
_m: Chỉ các giá trị mới nhất, được cập nhật sau mỗi lần lặp.
_c: Chỉ các giá trị cũ, trước khi cập nhật.
_b: Chỉ các giá trị ban đầu, được khởi tạo trước khi bắt đầu quá trình huấn luyện.
Định nghĩa về norm
Đơn giản và ngắn gọn, norm của một vector giống như độ dài của một đoạn thẳng. Tuy nhiên, trong không gian nhiều chiều, khái niệm "độ dài" này được tổng quát hóa bằng norm. Để làm việc với Deep Learning, có hai norm cần nhớ là L1 vào L2, tương ứng là norm bậc 1 và norm bậc 2.
Norm L1 (Norm Manhattan):
Định nghĩa: Norm L1 của một vector là tổng giá trị tuyệt đối của các thành phần của vector đó. (Trong một vector, mỗi chiều sẽ có một giá trị x_i)
Công thức: ||x||₁ = |x₁| + |x₂| + ... + |xₙ|
Hình dung: Nếu bạn muốn đi từ gốc tọa độ đến một điểm trong không gian 2 chiều, norm L1 sẽ cho biết tổng quát đường đi của bạn theo các cạnh của các ô vuông.
Norm L2 (Norm Euclidean):
Định nghĩa: Norm L2 của một vector là căn bậc hai của tổng bình phương các thành phần của vector đó.
Công thức: ||x||₂ = √(x₁² + x₂² + ... + xₙ²)
Hình dung: Norm L2 chính là khoảng cách Euclid quen thuộc mà chúng ta thường sử dụng. (Công thức toán cấp 3)
** Khái niệm độ dài được tổng quát hoá và áp dụng cho không gian với số chiều nhiều hơn. Nên “độ dài” được đặt trong ngoặc kép. Còn về các lý thuyết cụ thể hơn, mời bạn đọc tham khảo chương 3 của [1].
![Hình 1. Trên đây là tập hợp các vector (là các điểm trên đường màu đỏ) mà có “độ dài” bằng 1 tương ứng với norm 1 (bên trái) và norm 2 (bên phải). Hình từ [1].](https://images.spiderum.com/sp-images/5ef38e90d0db11efa7ac8f07c761e443.png)
Hình 1. Trên đây là tập hợp các vector (là các điểm trên đường màu đỏ) mà có “độ dài” bằng 1 tương ứng với norm 1 (bên trái) và norm 2 (bên phải). Hình từ [1].
L1 Regularization
Như đã trình bày ở phần mở đầu và phần giới thiệu về norm phía trên. L1 Regularization chỉ đơn giản cộng thêm một lượng R(w) vào trong hàm loss của quá trình huấn luyện, mà ở đó R(w) chính là norm 1, hay gọi chính quy là l1 norm (”eol-one” norm).

Phép chính quy hoá này nếu được áp dụng cùng với các thuật toán học dựa trên gradient sẽ ép rất nhiều trọng số tiến về 0 (hình tượng là tắt kết nối neural đó) và khiến cho bộ trọng số rải theo phân phối Laplace. Tính chất này hay được đề cập đến bằng tên gọi “sparsity” - thưa thớt. Tức là ám chỉ sự thưa các kết nối thực tế sẽ làm việc. Vì các kết nối có trọng số bằng 0 đơn giản là không kết nối. Để giải thích cho hiện tượng này, độc giả chỉ cần thực hiện đạo hàm hàm loss mới như bình thường. Thực tế, sau khi đạo hàm, term l1 norm sẽ chỉ còn lại lambda (hằng số), vì phép đạo hàm theo w và w thì bậc 1. Như vậy, mỗi lần cập nhật trọng số theo hàm loss mới này, ta trừ một lượng nhỏ vào trọng số ban đầu. Dần dần, biến trọng số ban đầu đó thành 0.
Số lượng trọng số bằng 0 thực tế khá lớn. Lý do thì có thể tưởng tượng hình học như sau: quan sát hình bên trái của Hình 1 sẽ thấy tập hợp các vector với giá trị norm cố định thường có hình “diamond”. Tuy nhiên, hàm loss thường gặp lại có hình uốn lượn (ellipsoidal), ví dụ như hàm mean square errors. Khi đó, giao của hai hình này sẽ thường ngay ở góc của hình “diamond”. Và tại đó, hầu hết giá trị sẽ là 0.
Ngoài ra, tính chất “sparsity” có thể được tổng quát hoá và buộc mạng neural “thưa thớt có trật tự” [2]. Tuy nhiên, L1 thuần không được ưa chuộng lắm cho các thuật toán học dựa trên gradient. Bởi vì hàm này không liên tục (do có dấu giá trị tuyệt đối), dẫn đến không thể đạo hàm tại một số điểm. Tuy nhiên, nếu cố chấp thì ta vẫn có thể dùng nó như bình thường bằng một số mẹo tham số hoá như là [3].
Mục tiêu cuối cùng vẫn là làm cho bộ trọng số có vùng giá trị nhỏ lại, hầu hết bằng 0, và phần còn lại khá nhỏ. Đây chính là mục tiêu của Weight Regularization. Tuy nhiên, như đã chỉ ra và các lý do khác, L1 khá hạn chế. Vì thế, hãy thử L2.
L2 Regularization
Mô hình hoá L2 Regularization cũng khá đơn giản, ta chỉ việc thay thế hàm R(w) từ norm 1 trong L1 sang norm 2. Lúc này, hàm R(w) sẽ có công thức như sau:

Đối với cùng một hàm mất mát trước khi thay đổi, việc áp dụng hình phạt L2 norm sẽ ưu tiên các bộ trọng số có độ lớn trọng số thấp hơn, tương ứng với sự thay đổi "ít đột ngột" hơn trong đầu ra cho một độ lệch nhỏ trong đầu vào. Điều này là tượng tự L1. Tuy nhiên, điểm khác nhau chính là sự điều chỉnh trọng số này là dựa trên độ lớn của trọng số. Hãy bắt đầu từ đạo hàm của hàm loss có regularization term. Hàm loss mới ta có được là:

Khi đó, đạo hàm sẽ thu được:

Khi viết dưới dạng này, phương pháp này đôi khi được gọi là suy giảm trọng số (weight decay), bởi vì nếu không có phần tử đầu tiên, hiệu ứng tổng thể của nó là làm suy giảm trọng số theo một hệ số tỷ lệ nhỏ λ (làm cho trọng số giảm dần về 0 theo cấp số nhân trong số lần lặp nếu ∇L = 0). Đối với (S)GD, điều chỉnh L2 và suy giảm trọng số trùng nhau. Tuy nhiên, đối với các loại thuật toán tối ưu hóa khác (ví dụ: SGD dựa trên momentum như Adam), thường áp dụng một bước xử lý hậu kiểm soát trên các gradient. Ký hiệu g(∇L) là gradient đã được xử lý hậu kiểm soát của hàm mất mát (không được điều chỉnh), chúng ta có thể viết một công thức cập nhật trọng số tổng quát (bỏ qua hằng số 2) như sau:

Điều đáng chú ý ở đây là hàm g() này chỉ đóng vai trò hậu xử lý gradient của hàm loss ban đầu chứ không phải cả hàm loss mới. Đây là điểm đặc biệt quan trọng để optimizer như Adam làm việc tốt [4].
Về cái tên tên Weight Decay
Trong các hàm loss phía trên, giả sử ta không có hàm loss nữa mà chỉ có phần Regularization thì chúng biến thành một hàm làm giảm giá trị của trọng số. Nói cách khác, Regularization chính là tìm cách giảm độ lớn của trọng số. Vì thế, rất đơn giản và trực tiếp, chúng còn được gọi là weight decay. Và đây chính là một tham số quan trọng mỗi khi bạn cân nhắc chọn và thiết lập một Optimizer cho quá trình huấn luyện mô hình của mình.
References
[1] M. P. Deisenroth, A. A. Faisal, C. S. Ong. MATHEMATICS FOR MACHINE LEARNING. Published by Cambridge University Press (2020).
[2] Scardapane, Simone, et al. "Group sparse regularization for deep neural networks." Neurocomputing 241 (2017): 81-89.
[3] Ziyin, Liu, and Zihao Wang. "spred: Solving L1 Penalty with SGD." International Conference on Machine Learning. PMLR, 2023.
[4] I. Loshchilov and F. Hutter. Decoupled weight decay regularization. In ICLR, 2019. 146

Khoa học - Công nghệ
/khoa-hoc-cong-nghe
Bài viết nổi bật khác
- Hot nhất
- Mới nhất
Hãy là người đầu tiên bình luận bài viết này