陳凱
用梯度下降的方法來實現線性回歸,是一種很經典的機器學習算法。然而,在基礎教育階段,由于受學生自身數學水平和信息技術水平的限制,他們對這種算法的基本原理以及實現回歸過程的程序代碼的學習和理解,難度還是比較大的。筆者曾經瀏覽過不少線性回歸方面的資料,發現學習路徑頗為陡峭。本篇文章試著從一個小游戲入手,一邊玩一邊想,逐層鋪墊,緩慢進階,希望能讓學習者避免迷失在諸多概念名詞和繁雜程序代碼之中,真正體驗到線性回歸算法的精髓。
● 冰糖葫蘆手工串
先來看游戲的玩法。有幾個糖葫蘆散落在桌面上,要求不改變糖葫蘆的位置,用一根竹簽將它們串起來,游戲一開始,竹簽只是平放在桌面邊緣,并沒有串起任何一個糖葫蘆,如圖1所示。
所謂的桌面,實際上是一個坐標軸,X軸和Y軸的范圍都是0到10;所謂的糖葫蘆,是四個面積比較大的點。實現串糖葫蘆游戲的Python代碼非常簡單,如圖2所示。
坐標軸上的這四個點的markersize參數是40,所以看上去就相當大,可以想見,這個參數越大,糖葫蘆也越大,游戲也越簡單。糖葫蘆所在的坐標直接寫在了plot函數的參數里,為代碼簡單清晰起見,這里并沒有引入隨機數。
所謂的竹簽,是通過y=t1+t2*x這個一元一次方程產生的直線,其中t1是方程的常數,t2是一次項系數。如果輸入t1為0.9,輸入t2為0.6,那么竹簽就串到了兩個糖葫蘆,如上頁圖3所示。
試過幾次就可以發現,t1和t2兩個數字的作用大不相同:前者決定了竹簽的位置,“靠上還是靠下”與“靠左還是靠右”其實是一回事;后者決定了竹簽的傾斜度。如果想讓竹簽往上擺一些,則t1要增加;如果想讓竹簽擺得平一些,則t2要減少。不妨再試一下,輸入t1為3,輸入t2為0.2,竹簽成功地串起了三個糖葫蘆,如上頁圖4所示。
多次嘗試之后,就能體會到t1和t2兩個數字的變化與最終直線形態之間的微妙關系。
● 竹簽離得有多遠
剛才是用人腦來判斷竹簽的位置t1和傾角t2,那么,怎么讓機器判斷位置和傾角呢?方法就是“看了再試,試了再看”。
例如,一開始,竹簽是平躺著的,竹簽到每個點的縱向的距離(為簡單起見,這里暫不考慮橫向的距離)是可以計算出來的,如圖5所示,其中左圖顯然離開理想結果還很遙遠,右圖已經比較接近目標了。
為了能夠計算出竹簽到糖葫蘆的距離,以評估竹簽與理想目標之間的差距,可以將代碼稍微修改一下。由于糖葫蘆可能在竹簽上面,也可能在竹簽下面,在計算距離時,數值可能是正,也可能是負,所以縱向距離統一進行二次方的運算(其實取絕對值也是一樣的)。又因為總共有四個糖葫蘆,所以要除以4,得到一個縱向距離的平均數。修改后的程序代碼如圖6所示。
e = (e1+e2+e3+e4)/4/2這段代碼,就是將糖葫蘆中心點和竹簽的縱向距離的平均值賦值給變量e,其中唯一難理解的是,為什么除以4取得平均值后又要除以2,其實這是為了讓后續的求導公式更加便捷,若是學習者學習過隱藏在線性回歸算法之后的數學原理,那么就能更清楚地知道這里為何要除以2。實際上,因為變量e的值只是用于在運算過程中指示竹簽位置和傾角離開理想值的差距,所以是否除以2,其實是無所謂的。假如說任務是要讓竹簽盡可能靠近糖葫蘆的中心,那么就需要觀察e值是否能收斂于某個值,然而,本文任務只要求把糖葫蘆串起來即可,因此后續的程序代碼中并不需要用到e變量。
上述代碼中,t1和t2的初始值都是0,運行后,發現得到的e的值是9.15227,這個數字顯然太大了,計算結果和人眼的直觀感受是符合的,所以需要改變t1和t2的值使得竹簽距離糖葫蘆更近一些。
● 竹簽需要重新擺
為了讓竹簽有可能串起更多糖葫蘆,就要重新調整竹簽的位置和傾角,其實就是更改t1和t2的值。但計算機怎么知道應該如何調整呢?
首先來看竹簽的位置,竹簽究竟是往上放還是往下放?應該需要調整多少距離?想象一下,如果大部分糖葫蘆都在竹簽上方,那么就要往上挪,反之就是往下挪,離開越遠,挪的距離就越多,程序中涉及的公式如上頁圖7所示。
這個公式是用微積分的方法推演出來的,然而就算是直觀上也是可以理解的,當竹簽平躺著的時候,得到的d1的值是4.18,這表示竹簽要往上方移動4.18個單位,如果d1的數值是負數,則表示要將竹簽往下方移動。
傾角的調整要復雜一些,因為這和每個糖葫蘆的橫向位置有關,涉及的公式如上頁圖8所示。
這同樣是用微積分求導的方法推演出來的,如果不想深入了解相關數學推導過程,那只需要驗證公式的合理性就可以了。當竹簽平躺著的時候,得到的值是23.65975。如果是正數,表示要把竹簽傾角增大,如果是負數,則表示將傾角變小。不過即便是直觀看圖9,也是能感受到公式的意義的,因為離開坐標軸原點越遠,點的坐標值本身的權重也就越大。
然而,按公式計算后,d1和d2這兩個數值都太大了,實際操作時矯枉過正,于是還要對這兩個值再乘上步長系數0.05,此數可大可小,要根據實際運行情況來調整。系數太小,則需要很多次調整才能達成目標;系數太大,則會在理想結果周圍來回跳躍。
最后,將d1和d2的修正值疊加到原先的t1和t2值上,將循環次數增加到200次,每一次循環中,做的都是同樣的事:用當前的t1和t2計算出直線擺放姿態,然后將運行效果比對預期效果后得到調整值d1和d2,再用d1和d2修正t1和t2值,這種方法稱為迭代法。圖10、圖11是完整的代碼和運行結果。
這個例子本身形象直觀,而代碼也比較簡單,若去掉繪圖相關的代碼,核心代碼僅有十多行,其中隱藏了超出學生當前水平的知識點,尤其是數學微積分和線性代數有關的知識技能,但同時又留下了繼續深入學習的路徑指引,為學習者提供了真切的機器學習算法的實踐體驗。