Capsule Network

Tản mạng về CNN (Văn học thì có tản văn, Toán thì có tản toán, mình làm về mạng neural sẽ là tản mạng chứ không phải là tản mạn ! ). Bài viết này sẽ đưa các bạn từ CNN đến Capsule Network.

Để có thể hiểu được khái về mạng Capsule Network mình sẽ quay về tiền thân của nó là mạng CNN.

1. Convolutional Neural Network (CNN)

Ảnh 1: Minh họa trong các bài toán xử lý ảnh

Mạng CNN hay còn được biết đến với tên là Mạng Nơ-rơn Tích Chập. Với đầu vào của mạng sẽ là một bức ảnh trải qua n lớp ẩn đầu ra sẽ là nhãn của ảnh đó tương ứng với các nhãn đã được huấn luyện trong mô hình.

Ví dụ: (Xem ở Ảnh 2) như khi ta chụp ảnh vật thể, bỏ vào mạng CNN để dự đoán thì kết quả sẽ trả ra vật thể đó sẽ có nhãn là gì (tùy thuộc vào cách setup mạng và số lượng nhãn của bộ dữ liệu mà ta đã huấn luyện trước đó).

Ảnh 2: Chức năng CNN

Cơ chế nhân tích chập hoạt động bằng cách áp dụng một bộ lọc (kernel hoặc filter) lên dữ liệu đầu vào. Bộ lọc này có thể là một ma trận nhỏ với các trọng số tương ứng. Bộ lọc này di chuyển trên dữ liệu đầu vào theo bước nhảy (stride) và thực hiện phép nhân tích chập để tính toán giá trị đầu ra. Lúc này ảnh của ta đã highlight được các đặc trưng có trong ảnh nhưng kích thước đầu vào của ảnh vẫn không thay đổi, để có thể lấy được các đặc trưng vừa có thể giảm được kích thước của ảnh. Ta sẽ cho qua phép Pooling.

Ý tưởng của phép Pooling sẽ là lấy giá trị lớn nhất (Max Pooling) của 1 cụm pixel, hoặc là giá trị trung bình (AVG Pooling) sau khi trải qua phép pooling lúc này ảnh thu được sẽ nhỏ hơn kích thước thước ảnh ban đầu nhưng vẫn giữ được các đặc trưng có bức ảnh.

Ảnh 3: Phép Pooling

Vậy câu hỏi đặt ra là nếu khuôn mặt ở Ảnh 2 bị đảo ngược thì CNN có nhận ra ?

Câu trả lời sẽ là CNN vẫn sẽ trả ra kết quả ảnh này có nhãn là “face”. Vì phép Pooling giúp tạo tính dịch chuyển không đổi từ đó vị trí của các đặc trưng đầu vào không quan trọng nữa. Nghĩa là nếu trong ảnh có đủ các đặc trương thì CNN sẽ trả ra kết quả mà không cần quan tâm đến vị trí của đặc trưng nằm ở đâu trong bức ảnh. Đơn cử như ở Ảnh 4 CNN vẫn trả ra được kết quả là do ảnh đầu vào vẫn đầy đủ các chi tiết của khuôn mặt như là: 2 mắt, 1 miệng, 1 mũi,…

Ảnh 4: Ảnh khuôn mặt đảo ngược

Thoạt nghe tính bất biến này có vẻ “vô hại” nhưng thực chất nó lại phát sinh “hàng tá vấn đề”:  Ví dụ thực tế về tác hại của tính bất biến của các đặc trưng trong mạng CNN có thể liên quan đến phân lớp ảnh.
Giả sử chúng ta có một mạng CNN được huấn luyện để phân loại hình ảnh động vật thành hai nhãn: “mèo” và “chó”. Mạng CNN này có tính bất biến với biến đổi dịch chuyển, tức là nó sẽ nhận diện được mèo hoặc chó dù ảnh của chúng được dịch chuyển một khoảng nhất định.
Tuy nhiên, một tác hại của tính bất biến này là mạng CNN có thể không phân biệt được giữa hai biến thể của cùng một loài động vật, ví dụ: mèo và chó đứng. Mặc dù hình dạng và đặc điểm của chúng có thể khác nhau, nhưng nếu cả hai được dịch chuyển cùng một khoảng, mạng CNN có thể cho cùng một kết quả cho cả hai ảnh, nghĩa là không phân biệt được giữa mèo và chó đứng.

Người được xem là cha đẻ của AI trên Reddit cho rằng “Phép pooling sử dụng trong mạng CNN là một sai lầm lớn và việc nó hoạt động hiệu quả là một thảm họa.” [1]

Từ đó mạng Capsule Network ra đời để giải quyết vấn đề trên.

2. Capsule Network (CapsNet)

Capsule NetWork hay còn được gọi tắt là CapsNet.
Nó được giới thiệu bởi Geoffrey Hinton và đồng nghiệp của ông vào năm 2017 [2] phát triển từ ý tưởng “Transforming Auto-Encoders” [3] với đề xuất thay thế cho mạng CNN trong việc xử lý hình ảnh.

Lý giải cho phát biểu [1]: Sử dụng phép pooling tuy làm giảm kích thước của dữ liệu và có thể lấy ra đặc trưng tốt nhất có thể nhưng lại làm mất đi thông tin quý giá về không gian giữa các đặc trưng khi dữ liệu đi qua các lớp trong mạng.

->Từ đó Hinton cho rằng cần phải bổ sung thông tin về mối quan hệ không gian giữa các đặc trưng

Lúc này các đặc trưng trong ảnh được gọi là một thực thể (Entity) được biểu diễn dưới dạng một vector đặc trưng (feature vector) chứ không phải là một số hoặc một ma trận số như trong mạng CNN (xem ở Ảnh 5, 6).

Ảnh 5: Mỗi đặc trưng được gọi là 1 thực thể (Entity)

Ảnh 6: Mỗi thực thể được biểu diễn dưới dạng vector

Các vector capsule từ lớp dưới được tổng hợp lại để tạo thành một biểu diễn tổng thể của các thực thể trong hình ảnh. Quá trình tổng hợp thông tin này thường được thực hiện bằng cách sử dụng weighted sum (tổng có trọng số) của các vector capsule. Trọng số này được tính toán thông qua quá trình dynamic routing, trong đó đầu ra của các capsule trong lớp dưới được sử dụng để cập nhật trọng số.

Ảnh 7: Tổng hợp các vector đặc trưng để tạo ra một kết quả dự đoán

Lúc này mỗi tế bào nơ-ron trong mạng sẽ không còn là một giá trị như các mạng thông thường ( xem Ảnh 8) mà mỗi tế bào nơ-ron sẽ là một ma trận trọng số (xem Ảnh 9).

Nên mạng này tương đối là nặng khi huấn luyện, chỉ phù hợp với những bộ dữ liệu nhỏ, do đòi đỏi phần cứng tính toán khá cao.

Ảnh 8: Tế bào nơ-ron trong mạng truyền thống

Ảnh 9: Tế bào nơ-ron trong mạng CapsNet

Cũng như mình đã trình bày ở phần trên (xem lại Ảnh 6) mỗi thực thể sẽ được biểu thị dưới dạng vector, nhưng lại sinh ra một vấn đề khác là vector của thực thể này lại quá lớn so với các thực thể khác. Vector này biểu diễn các thuộc tính của các thực thể và có độ lớn (độ dài) khác nhau.

Tuy nhiên, độ lớn của vector không thể đơn thuần được coi là chỉ số quan trọng để đo lường mức độ tồn tại của một đặc trưng. Thay vào đó, một khía cạnh quan trọng trong CapsNet là sự hiện diện của các đặc trưng, được xác định bởi hướng của các vector.

Lúc này vector sẽ được chuẩn hóa bởi công thức như sau (xem Ảnh 10) sau khi chuẩn hóa sẽ cho ra kết quả (xem Ảnh 11).

Ảnh 10: Vector chưa chuẩn hóa

Ảnh 11: Vector sau khi chuẩn hóa

Quá trình định tuyến động giữa các viên nang xảy ra để tính toán trọng số để xác định mức độ tương quan giữa các viên nang trong lớp dưới và các viên nang trong lớp tiếp theo. Đến cuối cùng sẽ vector có giá trị lớn nhất tương ứng với xác suất cao nhất và sẽ được chọn ra để xem vector này sẽ hội tụ về nhãn nào.

-> Mỗi khối capsule cố gắng dự đoán 1 kết quả và gửi nó đến các lớp capsule sau

Ảnh 12: Mỗi khối capsule cố gắng dự đoán 1 kết quả và gửi nó đến các lớp capsule sau

Để có thể trực quan hơn mình sẽ minh họa cách mà Mnist hoạt động trên mạng CapsNet.

3. CapsNet on MNIST

Ảnh 13: Minh họa trên bộ dữ liệu Mnist

Lúc này ảnh đầu vào sẽ là hình chụp chữ số viết tay của con số 7. Trải qua bước trích xuất ta thu được 3 đặc trưng gọi tắt là “Line 1”, “Line 2”, “Line 3”.

Từ 3 đặc trưng này các khối Capsule sẽ cố gắng “lắp ráp” lại và cho ra nhiều kết quả dự đoán sau đó mạng sẽ chọn kết quả vector có giá trị lớn nhất tương ứng với xác suất có độ chính xác cao nhất và tiến hành xem vector này sẽ hội tụ về label nào.

4. Ưu điểm CapsNet trên các bộ dữ liệu Overlap

Nhờ vào việc mỗi khối Capsule có khả năng dự đoán được một kết quả riêng biệt. CapsNet bảo tồn thông tin về mối quan hệ không gian giữa các thực thể trong ảnh. Điều này làm cho CapsNet đáng chú ý trong việc phân loại và nhận diện đối tượng trên các bộ dữ liệu chứa sự chồng chéo, nơi thông tin không gian là quan trọng.

Ảnh 14: Minh họa trên bộ dữ liệu Mnist bị Overlap

*Giải thích ký hiệu trong ảnh 14: R = Nhãn thực tế; L = Nhãn mà mạng CapsNet dự đoán ra.

Tài liệu tham khảo:

[1] https://urlvn.net/wsk6at

[2] Sabour, S., Frosst, N., & Hinton, G. E. (2017). Dynamic routing between capsules. Advances in neural information processing systems, 30.

[3] Hinton, G. E., Krizhevsky, A., & Wang, S. D. (2011). Transforming auto-encoders. In Artificial Neural Networks and Machine Learning–ICANN 2011: 21st International Conference on Artificial Neural Networks, Espoo, Finland, June 14-17, 2011, Proceedings, Part I 21 (pp. 44-51). Springer Berlin Heidelberg.

[4] Link PPT: https://docs.google.com/presentation/d/1T9wnPsnYXzYPn7rgFH88p7d49IIvreoc/edit?usp=sharing&ouid=105615835573785134039&rtpof=true&sd=true

3 Bình luận

  1. Nội dung bài viết hay, đầy đủ các kiến thức cơ bản, dễ tiếp cận kiến thức về Capsule Network.
    Bài viết nên bổ sung chú thích hình ảnh sẽ hoàn thiện hơn

Trả lời

Email của bạn sẽ không được hiển thị công khai. Các trường bắt buộc được đánh dấu *