Shogo's Blog

Oct 15, 2023 - 1 minute read - golang

半精度浮動小数点数をあつかうGoのライブラリを書いた

半精度浮動小数点数をあつかうGoのライブラリを書いてみました。

背景

なぜ書いたかというと、半精度浮動小数点数について勉強するためです。

  • 最近のAIブームでビット数の少ない浮動小数点数が注目されていて興味を持ったため
    • 最近の研究で、有効桁数はそこまで重要でないことがわかってきた
    • パラメーターの数が膨大なので、少しでもモデルを圧縮したい
  • CBORの実装読んでいたら、仕様の一部に半精度浮動小数点数が出てきたため

使い方

FromFloat64で倍精度浮動小数点型から半精度浮動小数点数へ変換できます。

import "github.com/shogo82148/float16"

func main() {
  a := float16.FromFloat64(1.0)
  fmt.Printf("%04x", a.Bits())
}

Float16.Bitsで内部表現を取得できるので、 この結果をシリアライズに使うのが主な使い方になると思います。

一応四則演算も実装してあります。

import "github.com/shogo82148/float16"

func main() {
  a := float16.FromFloat64(1.0)
  b := float16.FromFloat64(2.0)

  fmt.Printf("%f + %f = %f", a.Add(b))
  fmt.Printf("%f - %f = %f", a.Sub(b))
  fmt.Printf("%f * %f = %f", a.Mul(b))
  fmt.Printf("%f / %f = %f", a.Div(b))
}

ただし(自分で書いておいてなんですが)あまり実用性はないです。 というのも半精度浮動小数点数同士の演算結果は float64 型で正確に表現できます。 そのため float64 型で計算したあと半精度浮動小数点数に戻せば、まったく同じ計算ができます。

import "github.com/shogo82148/float16"

func main() {
  a := float16.FromFloat64(1.0).Float64()
  b := float16.FromFloat64(2.0).Float64()

  fmt.Printf("%f + %f = %f", float16.FromFloat64(a + b))
  fmt.Printf("%f - %f = %f", float16.FromFloat64(a - b))
  fmt.Printf("%f * %f = %f", float16.FromFloat64(a * b))
  fmt.Printf("%f / %f = %f", float16.FromFloat64(a / b))
}

AI関連で注目されているのは、専用ハードウェアを作れば回路規模が小さくて済むという利点があるからです。 ソフトウェアエミュレーションではこの点を活かせません。

実装

半精度浮動小数点数と倍精度浮動小数点数の変換

Float16 から float64 への変換は比較的楽です。 Wikipediaにビットの配置が載っているので、そのとおりに再配置すればOKです。

逆、 float64 から Float16 は少し面倒です。 精度が落ちるので適切に丸める必要があります。オーバーフローやアンダーフローが発生するケースも考慮しなければなりません。

四則演算

浮動小数点数の仕様通りにデコードしたあと、頑張って計算します。 計算結果の検証には Berkeley TestFloat を使用しました。

文字列への変換

これがなかなかに面倒・・・たとえば 0.3 にもっとも近い半精度浮動小数点数を愚直に10進数へ変換すると 0.300048828125 となります。 これでは長くて読みにくいので 0.3 にしたいですよね。 でもそこで四捨五入するのか考え出すと、よくわからなくなってきた・・・。

四捨五入のアルゴリズム、Goの標準ライブラリでは Ryū というアルゴリズムを使用しています。

この実装を参考にしようと思ったけど、なかなか複雑でハード・・・。ちゃんとした実装はまだできていません。

ちなみに Ryū は日本語の「龍」から取っているそうです。 こんなところで日本語が使われているとは。

参考