葛鈺峣,郭昱汝,周楨洋
(北方工業大學,北京 100144)
計算機視覺在人工智能里可以類比于人類的眼睛,是在感知層上最為重要的核心技術之一[1]。目前,計算機視覺的主要應用場景包括面部識別和檢測,車牌閱讀,照片編輯,高級機器人視覺,光學字符識別等[2]。深度學習框架最初主要用于深度學習的科研工作,代表性開源深度學習框架包括Theano,Caffe等[3]。近年來,國外的大型科技公司開始紛紛布局這一領域并開源,影響力比較大的包括谷歌的TensorFlow和臉書PyTorch。飛槳是國內首個也是當前唯一一 個功能完備的開源深度學習平臺,基于百度在深度學習領域的長期積累自主研發,并在2016年開源[4]。為了鼓勵開發者了解與參與深度學習開源項目,百度飛槳主辦了2021 PaddlePaddle Hackathon飛槳黑客馬拉松活動,吸引了來自全球的開發者。本單測項目為參與百度黑客馬拉松參賽項目。本文涉及的API是基于百度飛槳的paddle.nn.UpsamplingNearest2D API,該API的最近鄰插值圖像處理方法主要應用于圖像大小的調整。本文介紹了paddle.nn.UpsamplingNearest2D的作用及原理,并給出了部分復現和測試的代碼。
最近臨插值算法是利用已知圖像中的像素點來填充新圖像對應位置的像素,基于這種方法可以實現對原圖的放大或者縮小。由于其本質是對原圖像像素的復制和抽樣,故不會產生新的像素,其新圖像的每個像素都是基于其對應的原圖像附近的像素而生成的。
原圖片和縮放后圖片的寬高分別是SW,SH,DW,DH縮放原理如圖1所示。

圖1 縮放圖示
設原圖像的寬和高分別為W_1和H_1,縮放后圖像的寬和高為W_2和H_2,水平和垂直比例scale_factor的計算公式如式(1)所示。

目標圖像中的點(x,y)對應變換尺寸前的圖像坐標為(x0,y0),其中:x0=int(x*scale_factor[0]),y0=int(y*scale_factor[1])[5]
函數包含4個輸入參數一個輸出。4個輸入分別為4D張量代表原圖像,縮放倍數,縮放后圖像的長寬方向的像素數,輸入的數據格式。輸出為4D張量代表調整后圖像。本文復現的函數的輸入參數中img為4D張量的原圖像,scale_fator為縮放倍數,size為縮放后的圖像尺寸,data_formate為輸入數據格式。其中針對數據格式為NCHW(num_batches,channels,height,width)或者NHWC(num_batches,height,width,channels),默認值:'NCHW'。
下面展示復現后代函數碼。
#篇幅關系,僅以數據格式為NCHW為例:

本次測試使用的測試工具為pytest,測試內容包含參數覆蓋、數據類型覆蓋、異常輸入等方面,下面展示前兩部分。
首先對API的參數覆蓋進行測試,針對使用不同的輸入參數size和scale_factor進行分別測試。通過使用兩種不同的表達方式對圖像進行調整來測試此API。

經過對參數覆蓋測試,測試結果無誤,證明該API正確性和穩定性可以得到保證。


經過對數據類型測試,測試結果無誤,證明該API正確性和穩定性可以得到保證。
API說明:
輸入為4-D Tensor時形狀為(num_batches,channels,in_h,in_w)或者(num_batches,in_h,in_w,channels),調整大小只適用于高度和寬度對應的維度。
參數說明:
(1)size-輸出Tensor,輸入為4D張量,形狀為(out_h,out_w)的2-D Tensor。如果size是列表,每一個元素可以是整數或者形狀為[1]的變量。如果size是變量,則其維度大小為1。默認值為None。
(2)scale_factor輸入的高度或寬度的乘數因子。size和scale_factor至少要設置一個。size的優先級高于scale_factor。默認值為None。如果scale_factor是一個list或tuple,其必須與輸入的shape匹配。
(3)data_format指定輸入的數據格式,輸出的數據格式將與輸入保持一致。對于4-D Tensor,支持NCHW(num_batches,channels,height,width)或 者NHWC(num_batches,height,width,channels),默認值:'NCHW'。
本文從參數覆蓋、正確性驗證、數據類型覆蓋、異常輸入等方面對飛槳框架中paddle.nn.UpsamplingNearest2D進行了復現和單元測試,并未發現問題。保證百度飛槳框架paddle.nn.UpsamplingNearest2D的正確性與穩定性,完整的代碼因篇幅關系在此不做贅述。單測相關代碼已在Github網站PaddlePaddle/PaddleTest/[PaddlePaddle hackath-on]add UpsamplingNearest2D unittest下開源。