1108 lines
28 KiB
Go
1108 lines
28 KiB
Go
// Package decimal implements an arbitrary precision fixed-point decimal.
|
|
//
|
|
// To use as part of a struct:
|
|
//
|
|
// type Struct struct {
|
|
// Number Decimal
|
|
// }
|
|
//
|
|
// The zero-value of a Decimal is 0, as you would expect.
|
|
//
|
|
// The best way to create a new Decimal is to use decimal.NewFromString, ex:
|
|
//
|
|
// n, err := decimal.NewFromString("-123.4567")
|
|
// n.String() // output: "-123.4567"
|
|
//
|
|
// NOTE: This can "only" represent numbers with a maximum of 2^31 digits
|
|
// after the decimal point.
|
|
package decimal
|
|
|
|
import (
|
|
"database/sql/driver"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"math"
|
|
"math/big"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// DivisionPrecision is the number of decimal places in the result when it
|
|
// doesn't divide exactly.
|
|
//
|
|
// Example:
|
|
//
|
|
// d1 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3)
|
|
// d1.String() // output: "0.6666666666666667"
|
|
// d2 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(30000)
|
|
// d2.String() // output: "0.0000666666666667"
|
|
// d3 := decimal.NewFromFloat(20000).Div(decimal.NewFromFloat(3)
|
|
// d3.String() // output: "6666.6666666666666667"
|
|
// decimal.DivisionPrecision = 3
|
|
// d4 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3)
|
|
// d4.String() // output: "0.667"
|
|
//
|
|
var DivisionPrecision = 16
|
|
|
|
// MarshalJSONWithoutQuotes should be set to true if you want the decimal to
|
|
// be JSON marshaled as a number, instead of as a string.
|
|
// WARNING: this is dangerous for decimals with many digits, since many JSON
|
|
// unmarshallers (ex: Javascript's) will unmarshal JSON numbers to IEEE 754
|
|
// double-precision floating point numbers, which means you can potentially
|
|
// silently lose precision.
|
|
var MarshalJSONWithoutQuotes = false
|
|
|
|
// Zero constant, to make computations faster.
|
|
var Zero = New(0, 1)
|
|
|
|
// fiveDec used in Cash Rounding
|
|
var fiveDec = New(5, 0)
|
|
|
|
var zeroInt = big.NewInt(0)
|
|
var oneInt = big.NewInt(1)
|
|
var twoInt = big.NewInt(2)
|
|
var fourInt = big.NewInt(4)
|
|
var fiveInt = big.NewInt(5)
|
|
var tenInt = big.NewInt(10)
|
|
var twentyInt = big.NewInt(20)
|
|
|
|
// Decimal represents a fixed-point decimal. It is immutable.
//
// The represented number is value * 10^exp.
type Decimal struct {
	// value is the unscaled coefficient. It is nil in the zero Decimal,
	// which represents 0.
	value *big.Int

	// exp is the power of ten applied to value.
	//
	// NOTE(vadim): this is deliberately an int32. Exponent sums and
	// differences are computed in wider types (and have been cast to float64),
	// and a 32-bit exponent keeps those exact. If we cared about being able to
	// represent every possible decimal, we could make exp a *big.Int but it
	// would hurt performance and numbers like that are unrealistic.
	exp int32
}
|
|
|
|
// New returns a new fixed-point decimal, value * 10 ^ exp.
|
|
func New(value int64, exp int32) Decimal {
|
|
return Decimal{
|
|
value: big.NewInt(value),
|
|
exp: exp,
|
|
}
|
|
}
|
|
|
|
// NewFromBigInt returns a new Decimal from a big.Int, value * 10 ^ exp
|
|
func NewFromBigInt(value *big.Int, exp int32) Decimal {
|
|
return Decimal{
|
|
value: big.NewInt(0).Set(value),
|
|
exp: exp,
|
|
}
|
|
}
|
|
|
|
// NewFromString returns a new Decimal from a string representation.
|
|
//
|
|
// Example:
|
|
//
|
|
// d, err := NewFromString("-123.45")
|
|
// d2, err := NewFromString(".0001")
|
|
//
|
|
func NewFromString(value string) (Decimal, error) {
|
|
originalInput := value
|
|
var intString string
|
|
var exp int64
|
|
|
|
// Check if number is using scientific notation
|
|
eIndex := strings.IndexAny(value, "Ee")
|
|
if eIndex != -1 {
|
|
expInt, err := strconv.ParseInt(value[eIndex+1:], 10, 32)
|
|
if err != nil {
|
|
if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange {
|
|
return Decimal{}, fmt.Errorf("can't convert %s to decimal: fractional part too long", value)
|
|
}
|
|
return Decimal{}, fmt.Errorf("can't convert %s to decimal: exponent is not numeric", value)
|
|
}
|
|
value = value[:eIndex]
|
|
exp = expInt
|
|
}
|
|
|
|
parts := strings.Split(value, ".")
|
|
if len(parts) == 1 {
|
|
// There is no decimal point, we can just parse the original string as
|
|
// an int
|
|
intString = value
|
|
} else if len(parts) == 2 {
|
|
// strip the insignificant digits for more accurate comparisons.
|
|
decimalPart := strings.TrimRight(parts[1], "0")
|
|
intString = parts[0] + decimalPart
|
|
expInt := -len(decimalPart)
|
|
exp += int64(expInt)
|
|
} else {
|
|
return Decimal{}, fmt.Errorf("can't convert %s to decimal: too many .s", value)
|
|
}
|
|
|
|
dValue := new(big.Int)
|
|
_, ok := dValue.SetString(intString, 10)
|
|
if !ok {
|
|
return Decimal{}, fmt.Errorf("can't convert %s to decimal", value)
|
|
}
|
|
|
|
if exp < math.MinInt32 || exp > math.MaxInt32 {
|
|
// NOTE(vadim): I doubt a string could realistically be this long
|
|
return Decimal{}, fmt.Errorf("can't convert %s to decimal: fractional part too long", originalInput)
|
|
}
|
|
|
|
return Decimal{
|
|
value: dValue,
|
|
exp: int32(exp),
|
|
}, nil
|
|
}
|
|
|
|
// RequireFromString returns a new Decimal from a string representation
|
|
// or panics if NewFromString would have returned an error.
|
|
//
|
|
// Example:
|
|
//
|
|
// d := RequireFromString("-123.45")
|
|
// d2 := RequireFromString(".0001")
|
|
//
|
|
func RequireFromString(value string) Decimal {
|
|
dec, err := NewFromString(value)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return dec
|
|
}
|
|
|
|
// NewFromFloat converts a float64 to Decimal.
|
|
//
|
|
// Example:
|
|
//
|
|
// NewFromFloat(123.45678901234567).String() // output: "123.4567890123456"
|
|
// NewFromFloat(.00000000000000001).String() // output: "0.00000000000000001"
|
|
//
|
|
// NOTE: some float64 numbers can take up about 300 bytes of memory in decimal representation.
|
|
// Consider using NewFromFloatWithExponent if space is more important than precision.
|
|
//
|
|
// NOTE: this will panic on NaN, +/-inf
|
|
func NewFromFloat(value float64) Decimal {
|
|
return NewFromFloatWithExponent(value, math.MinInt32)
|
|
}
|
|
|
|
// NewFromFloatWithExponent converts a float64 to Decimal, with an arbitrary
|
|
// number of fractional digits.
|
|
//
|
|
// Example:
|
|
//
|
|
// NewFromFloatWithExponent(123.456, -2).String() // output: "123.46"
|
|
//
|
|
func NewFromFloatWithExponent(value float64, exp int32) Decimal {
|
|
if math.IsNaN(value) || math.IsInf(value, 0) {
|
|
panic(fmt.Sprintf("Cannot create a Decimal from %v", value))
|
|
}
|
|
|
|
bits := math.Float64bits(value)
|
|
mant := bits & (1<<52 - 1)
|
|
exp2 := int32((bits >> 52) & (1<<11 - 1))
|
|
sign := bits >> 63
|
|
|
|
if exp2 == 0 {
|
|
// specials
|
|
if mant == 0 {
|
|
return Decimal{}
|
|
} else {
|
|
// subnormal
|
|
exp2++
|
|
}
|
|
} else {
|
|
// normal
|
|
mant |= 1 << 52
|
|
}
|
|
|
|
exp2 -= 1023 + 52
|
|
|
|
// normalizing base-2 values
|
|
for mant&1 == 0 {
|
|
mant = mant >> 1
|
|
exp2++
|
|
}
|
|
|
|
// maximum number of fractional base-10 digits to represent 2^N exactly cannot be more than -N if N<0
|
|
if exp < 0 && exp < exp2 {
|
|
if exp2 < 0 {
|
|
exp = exp2
|
|
} else {
|
|
exp = 0
|
|
}
|
|
}
|
|
|
|
// representing 10^M * 2^N as 5^M * 2^(M+N)
|
|
exp2 -= exp
|
|
|
|
temp := big.NewInt(1)
|
|
dMant := big.NewInt(int64(mant))
|
|
|
|
// applying 5^M
|
|
if exp > 0 {
|
|
temp = temp.SetInt64(int64(exp))
|
|
temp = temp.Exp(fiveInt, temp, nil)
|
|
} else if exp < 0 {
|
|
temp = temp.SetInt64(-int64(exp))
|
|
temp = temp.Exp(fiveInt, temp, nil)
|
|
dMant = dMant.Mul(dMant, temp)
|
|
temp = temp.SetUint64(1)
|
|
}
|
|
|
|
// applying 2^(M+N)
|
|
if exp2 > 0 {
|
|
dMant = dMant.Lsh(dMant, uint(exp2))
|
|
} else if exp2 < 0 {
|
|
temp = temp.Lsh(temp, uint(-exp2))
|
|
}
|
|
|
|
// rounding and downscaling
|
|
if exp > 0 || exp2 < 0 {
|
|
halfDown := new(big.Int).Rsh(temp, 1)
|
|
dMant = dMant.Add(dMant, halfDown)
|
|
dMant = dMant.Quo(dMant, temp)
|
|
}
|
|
|
|
if sign == 1 {
|
|
dMant = dMant.Neg(dMant)
|
|
}
|
|
|
|
return Decimal{
|
|
value: dMant,
|
|
exp: exp,
|
|
}
|
|
}
|
|
|
|
// rescale returns a rescaled version of the decimal. Returned
|
|
// decimal may be less precise if the given exponent is bigger
|
|
// than the initial exponent of the Decimal.
|
|
// NOTE: this will truncate, NOT round
|
|
//
|
|
// Example:
|
|
//
|
|
// d := New(12345, -4)
|
|
// d2 := d.rescale(-1)
|
|
// d3 := d2.rescale(-4)
|
|
// println(d1)
|
|
// println(d2)
|
|
// println(d3)
|
|
//
|
|
// Output:
|
|
//
|
|
// 1.2345
|
|
// 1.2
|
|
// 1.2000
|
|
//
|
|
func (d Decimal) rescale(exp int32) Decimal {
|
|
d.ensureInitialized()
|
|
// NOTE(vadim): must convert exps to float64 before - to prevent overflow
|
|
diff := math.Abs(float64(exp) - float64(d.exp))
|
|
value := new(big.Int).Set(d.value)
|
|
|
|
expScale := new(big.Int).Exp(tenInt, big.NewInt(int64(diff)), nil)
|
|
if exp > d.exp {
|
|
value = value.Quo(value, expScale)
|
|
} else if exp < d.exp {
|
|
value = value.Mul(value, expScale)
|
|
}
|
|
|
|
return Decimal{
|
|
value: value,
|
|
exp: exp,
|
|
}
|
|
}
|
|
|
|
// Abs returns the absolute value of the decimal.
|
|
func (d Decimal) Abs() Decimal {
|
|
d.ensureInitialized()
|
|
d2Value := new(big.Int).Abs(d.value)
|
|
return Decimal{
|
|
value: d2Value,
|
|
exp: d.exp,
|
|
}
|
|
}
|
|
|
|
// Add returns d + d2.
|
|
func (d Decimal) Add(d2 Decimal) Decimal {
|
|
baseScale := min(d.exp, d2.exp)
|
|
rd := d.rescale(baseScale)
|
|
rd2 := d2.rescale(baseScale)
|
|
|
|
d3Value := new(big.Int).Add(rd.value, rd2.value)
|
|
return Decimal{
|
|
value: d3Value,
|
|
exp: baseScale,
|
|
}
|
|
}
|
|
|
|
// Sub returns d - d2.
|
|
func (d Decimal) Sub(d2 Decimal) Decimal {
|
|
baseScale := min(d.exp, d2.exp)
|
|
rd := d.rescale(baseScale)
|
|
rd2 := d2.rescale(baseScale)
|
|
|
|
d3Value := new(big.Int).Sub(rd.value, rd2.value)
|
|
return Decimal{
|
|
value: d3Value,
|
|
exp: baseScale,
|
|
}
|
|
}
|
|
|
|
// Neg returns -d.
|
|
func (d Decimal) Neg() Decimal {
|
|
d.ensureInitialized()
|
|
val := new(big.Int).Neg(d.value)
|
|
return Decimal{
|
|
value: val,
|
|
exp: d.exp,
|
|
}
|
|
}
|
|
|
|
// Mul returns d * d2.
|
|
func (d Decimal) Mul(d2 Decimal) Decimal {
|
|
d.ensureInitialized()
|
|
d2.ensureInitialized()
|
|
|
|
expInt64 := int64(d.exp) + int64(d2.exp)
|
|
if expInt64 > math.MaxInt32 || expInt64 < math.MinInt32 {
|
|
// NOTE(vadim): better to panic than give incorrect results, as
|
|
// Decimals are usually used for money
|
|
panic(fmt.Sprintf("exponent %v overflows an int32!", expInt64))
|
|
}
|
|
|
|
d3Value := new(big.Int).Mul(d.value, d2.value)
|
|
return Decimal{
|
|
value: d3Value,
|
|
exp: int32(expInt64),
|
|
}
|
|
}
|
|
|
|
// Div returns d / d2. If it doesn't divide exactly, the result will have
|
|
// DivisionPrecision digits after the decimal point.
|
|
func (d Decimal) Div(d2 Decimal) Decimal {
|
|
return d.DivRound(d2, int32(DivisionPrecision))
|
|
}
|
|
|
|
// QuoRem does divsion with remainder
|
|
// d.QuoRem(d2,precision) returns quotient q and remainder r such that
|
|
// d = d2 * q + r, q an integer multiple of 10^(-precision)
|
|
// 0 <= r < abs(d2) * 10 ^(-precision) if d>=0
|
|
// 0 >= r > -abs(d2) * 10 ^(-precision) if d<0
|
|
// Note that precision<0 is allowed as input.
|
|
func (d Decimal) QuoRem(d2 Decimal, precision int32) (Decimal, Decimal) {
|
|
d.ensureInitialized()
|
|
d2.ensureInitialized()
|
|
if d2.value.Sign() == 0 {
|
|
panic("decimal division by 0")
|
|
}
|
|
scale := -precision
|
|
e := int64(d.exp - d2.exp - scale)
|
|
if e > math.MaxInt32 || e < math.MinInt32 {
|
|
panic("overflow in decimal QuoRem")
|
|
}
|
|
var aa, bb, expo big.Int
|
|
var scalerest int32
|
|
// d = a 10^ea
|
|
// d2 = b 10^eb
|
|
if e < 0 {
|
|
aa = *d.value
|
|
expo.SetInt64(-e)
|
|
bb.Exp(tenInt, &expo, nil)
|
|
bb.Mul(d2.value, &bb)
|
|
scalerest = d.exp
|
|
// now aa = a
|
|
// bb = b 10^(scale + eb - ea)
|
|
} else {
|
|
expo.SetInt64(e)
|
|
aa.Exp(tenInt, &expo, nil)
|
|
aa.Mul(d.value, &aa)
|
|
bb = *d2.value
|
|
scalerest = scale + d2.exp
|
|
// now aa = a ^ (ea - eb - scale)
|
|
// bb = b
|
|
}
|
|
var q, r big.Int
|
|
q.QuoRem(&aa, &bb, &r)
|
|
dq := Decimal{value: &q, exp: scale}
|
|
dr := Decimal{value: &r, exp: scalerest}
|
|
return dq, dr
|
|
}
|
|
|
|
// DivRound divides and rounds to a given precision
|
|
// i.e. to an integer multiple of 10^(-precision)
|
|
// for a positive quotient digit 5 is rounded up, away from 0
|
|
// if the quotient is negative then digit 5 is rounded down, away from 0
|
|
// Note that precision<0 is allowed as input.
|
|
func (d Decimal) DivRound(d2 Decimal, precision int32) Decimal {
|
|
// QuoRem already checks initialization
|
|
q, r := d.QuoRem(d2, precision)
|
|
// the actual rounding decision is based on comparing r*10^precision and d2/2
|
|
// instead compare 2 r 10 ^precision and d2
|
|
var rv2 big.Int
|
|
rv2.Abs(r.value)
|
|
rv2.Lsh(&rv2, 1)
|
|
// now rv2 = abs(r.value) * 2
|
|
r2 := Decimal{value: &rv2, exp: r.exp + precision}
|
|
// r2 is now 2 * r * 10 ^ precision
|
|
var c = r2.Cmp(d2.Abs())
|
|
|
|
if c < 0 {
|
|
return q
|
|
}
|
|
|
|
if d.value.Sign()*d2.value.Sign() < 0 {
|
|
return q.Sub(New(1, -precision))
|
|
}
|
|
|
|
return q.Add(New(1, -precision))
|
|
}
|
|
|
|
// Mod returns d % d2.
|
|
func (d Decimal) Mod(d2 Decimal) Decimal {
|
|
quo := d.Div(d2).Truncate(0)
|
|
return d.Sub(d2.Mul(quo))
|
|
}
|
|
|
|
// Pow returns d to the power d2
|
|
func (d Decimal) Pow(d2 Decimal) Decimal {
|
|
var temp Decimal
|
|
if d2.IntPart() == 0 {
|
|
return NewFromFloat(1)
|
|
}
|
|
temp = d.Pow(d2.Div(NewFromFloat(2)))
|
|
if d2.IntPart()%2 == 0 {
|
|
return temp.Mul(temp)
|
|
}
|
|
if d2.IntPart() > 0 {
|
|
return temp.Mul(temp).Mul(d)
|
|
}
|
|
return temp.Mul(temp).Div(d)
|
|
}
|
|
|
|
// Cmp compares the numbers represented by d and d2 and returns:
|
|
//
|
|
// -1 if d < d2
|
|
// 0 if d == d2
|
|
// +1 if d > d2
|
|
//
|
|
func (d Decimal) Cmp(d2 Decimal) int {
|
|
d.ensureInitialized()
|
|
d2.ensureInitialized()
|
|
|
|
if d.exp == d2.exp {
|
|
return d.value.Cmp(d2.value)
|
|
}
|
|
|
|
baseExp := min(d.exp, d2.exp)
|
|
rd := d.rescale(baseExp)
|
|
rd2 := d2.rescale(baseExp)
|
|
|
|
return rd.value.Cmp(rd2.value)
|
|
}
|
|
|
|
// Equal returns whether the numbers represented by d and d2 are equal.
|
|
func (d Decimal) Equal(d2 Decimal) bool {
|
|
return d.Cmp(d2) == 0
|
|
}
|
|
|
|
// Equals is deprecated, please use Equal method instead
|
|
func (d Decimal) Equals(d2 Decimal) bool {
|
|
return d.Equal(d2)
|
|
}
|
|
|
|
// GreaterThan (GT) returns true when d is greater than d2.
|
|
func (d Decimal) GreaterThan(d2 Decimal) bool {
|
|
return d.Cmp(d2) == 1
|
|
}
|
|
|
|
// GreaterThanOrEqual (GTE) returns true when d is greater than or equal to d2.
|
|
func (d Decimal) GreaterThanOrEqual(d2 Decimal) bool {
|
|
cmp := d.Cmp(d2)
|
|
return cmp == 1 || cmp == 0
|
|
}
|
|
|
|
// LessThan (LT) returns true when d is less than d2.
|
|
func (d Decimal) LessThan(d2 Decimal) bool {
|
|
return d.Cmp(d2) == -1
|
|
}
|
|
|
|
// LessThanOrEqual (LTE) returns true when d is less than or equal to d2.
|
|
func (d Decimal) LessThanOrEqual(d2 Decimal) bool {
|
|
cmp := d.Cmp(d2)
|
|
return cmp == -1 || cmp == 0
|
|
}
|
|
|
|
// Sign returns:
|
|
//
|
|
// -1 if d < 0
|
|
// 0 if d == 0
|
|
// +1 if d > 0
|
|
//
|
|
func (d Decimal) Sign() int {
|
|
if d.value == nil {
|
|
return 0
|
|
}
|
|
return d.value.Sign()
|
|
}
|
|
|
|
// Exponent returns the exponent, or scale component of the decimal.
|
|
func (d Decimal) Exponent() int32 {
|
|
return d.exp
|
|
}
|
|
|
|
// Coefficient returns the coefficient of the decimal. It is scaled by 10^Exponent()
|
|
func (d Decimal) Coefficient() *big.Int {
|
|
// we copy the coefficient so that mutating the result does not mutate the
|
|
// Decimal.
|
|
return big.NewInt(0).Set(d.value)
|
|
}
|
|
|
|
// IntPart returns the integer component of the decimal.
|
|
func (d Decimal) IntPart() int64 {
|
|
scaledD := d.rescale(0)
|
|
return scaledD.value.Int64()
|
|
}
|
|
|
|
// Rat returns a rational number representation of the decimal.
|
|
func (d Decimal) Rat() *big.Rat {
|
|
d.ensureInitialized()
|
|
if d.exp <= 0 {
|
|
// NOTE(vadim): must negate after casting to prevent int32 overflow
|
|
denom := new(big.Int).Exp(tenInt, big.NewInt(-int64(d.exp)), nil)
|
|
return new(big.Rat).SetFrac(d.value, denom)
|
|
}
|
|
|
|
mul := new(big.Int).Exp(tenInt, big.NewInt(int64(d.exp)), nil)
|
|
num := new(big.Int).Mul(d.value, mul)
|
|
return new(big.Rat).SetFrac(num, oneInt)
|
|
}
|
|
|
|
// Float64 returns the nearest float64 value for d and a bool indicating
|
|
// whether f represents d exactly.
|
|
// For more details, see the documentation for big.Rat.Float64
|
|
func (d Decimal) Float64() (f float64, exact bool) {
|
|
return d.Rat().Float64()
|
|
}
|
|
|
|
// String returns the string representation of the decimal
|
|
// with the fixed point.
|
|
//
|
|
// Example:
|
|
//
|
|
// d := New(-12345, -3)
|
|
// println(d.String())
|
|
//
|
|
// Output:
|
|
//
|
|
// -12.345
|
|
//
|
|
func (d Decimal) String() string {
|
|
return d.string(true)
|
|
}
|
|
|
|
// StringFixed returns a rounded fixed-point string with places digits after
|
|
// the decimal point.
|
|
//
|
|
// Example:
|
|
//
|
|
// NewFromFloat(0).StringFixed(2) // output: "0.00"
|
|
// NewFromFloat(0).StringFixed(0) // output: "0"
|
|
// NewFromFloat(5.45).StringFixed(0) // output: "5"
|
|
// NewFromFloat(5.45).StringFixed(1) // output: "5.5"
|
|
// NewFromFloat(5.45).StringFixed(2) // output: "5.45"
|
|
// NewFromFloat(5.45).StringFixed(3) // output: "5.450"
|
|
// NewFromFloat(545).StringFixed(-1) // output: "550"
|
|
//
|
|
func (d Decimal) StringFixed(places int32) string {
|
|
rounded := d.Round(places)
|
|
return rounded.string(false)
|
|
}
|
|
|
|
// StringFixedBank returns a banker rounded fixed-point string with places digits
|
|
// after the decimal point.
|
|
//
|
|
// Example:
|
|
//
|
|
// NewFromFloat(0).StringFixed(2) // output: "0.00"
|
|
// NewFromFloat(0).StringFixed(0) // output: "0"
|
|
// NewFromFloat(5.45).StringFixed(0) // output: "5"
|
|
// NewFromFloat(5.45).StringFixed(1) // output: "5.4"
|
|
// NewFromFloat(5.45).StringFixed(2) // output: "5.45"
|
|
// NewFromFloat(5.45).StringFixed(3) // output: "5.450"
|
|
// NewFromFloat(545).StringFixed(-1) // output: "550"
|
|
//
|
|
func (d Decimal) StringFixedBank(places int32) string {
|
|
rounded := d.RoundBank(places)
|
|
return rounded.string(false)
|
|
}
|
|
|
|
// StringFixedCash returns a Swedish/Cash rounded fixed-point string. For
|
|
// more details see the documentation at function RoundCash.
|
|
func (d Decimal) StringFixedCash(interval uint8) string {
|
|
rounded := d.RoundCash(interval)
|
|
return rounded.string(false)
|
|
}
|
|
|
|
// Round rounds the decimal to places decimal places.
|
|
// If places < 0, it will round the integer part to the nearest 10^(-places).
|
|
//
|
|
// Example:
|
|
//
|
|
// NewFromFloat(5.45).Round(1).String() // output: "5.5"
|
|
// NewFromFloat(545).Round(-1).String() // output: "550"
|
|
//
|
|
func (d Decimal) Round(places int32) Decimal {
|
|
// truncate to places + 1
|
|
ret := d.rescale(-places - 1)
|
|
|
|
// add sign(d) * 0.5
|
|
if ret.value.Sign() < 0 {
|
|
ret.value.Sub(ret.value, fiveInt)
|
|
} else {
|
|
ret.value.Add(ret.value, fiveInt)
|
|
}
|
|
|
|
// floor for positive numbers, ceil for negative numbers
|
|
_, m := ret.value.DivMod(ret.value, tenInt, new(big.Int))
|
|
ret.exp++
|
|
if ret.value.Sign() < 0 && m.Cmp(zeroInt) != 0 {
|
|
ret.value.Add(ret.value, oneInt)
|
|
}
|
|
|
|
return ret
|
|
}
|
|
|
|
// RoundBank rounds the decimal to places decimal places.
|
|
// If the final digit to round is equidistant from the nearest two integers the
|
|
// rounded value is taken as the even number
|
|
//
|
|
// If places < 0, it will round the integer part to the nearest 10^(-places).
|
|
//
|
|
// Examples:
|
|
//
|
|
// NewFromFloat(5.45).Round(1).String() // output: "5.4"
|
|
// NewFromFloat(545).Round(-1).String() // output: "540"
|
|
// NewFromFloat(5.46).Round(1).String() // output: "5.5"
|
|
// NewFromFloat(546).Round(-1).String() // output: "550"
|
|
// NewFromFloat(5.55).Round(1).String() // output: "5.6"
|
|
// NewFromFloat(555).Round(-1).String() // output: "560"
|
|
//
|
|
func (d Decimal) RoundBank(places int32) Decimal {
|
|
|
|
round := d.Round(places)
|
|
remainder := d.Sub(round).Abs()
|
|
|
|
half := New(5, -places-1)
|
|
if remainder.Cmp(half) == 0 && round.value.Bit(0) != 0 {
|
|
if round.value.Sign() < 0 {
|
|
round.value.Add(round.value, oneInt)
|
|
} else {
|
|
round.value.Sub(round.value, oneInt)
|
|
}
|
|
}
|
|
|
|
return round
|
|
}
|
|
|
|
// RoundCash aka Cash/Penny/öre rounding rounds decimal to a specific
|
|
// interval. The amount payable for a cash transaction is rounded to the nearest
|
|
// multiple of the minimum currency unit available. The following intervals are
|
|
// available: 5, 10, 15, 25, 50 and 100; any other number throws a panic.
|
|
// 5: 5 cent rounding 3.43 => 3.45
|
|
// 10: 10 cent rounding 3.45 => 3.50 (5 gets rounded up)
|
|
// 15: 10 cent rounding 3.45 => 3.40 (5 gets rounded down)
|
|
// 25: 25 cent rounding 3.41 => 3.50
|
|
// 50: 50 cent rounding 3.75 => 4.00
|
|
// 100: 100 cent rounding 3.50 => 4.00
|
|
// For more details: https://en.wikipedia.org/wiki/Cash_rounding
|
|
func (d Decimal) RoundCash(interval uint8) Decimal {
|
|
var iVal *big.Int
|
|
switch interval {
|
|
case 5:
|
|
iVal = twentyInt
|
|
case 10:
|
|
iVal = tenInt
|
|
case 15:
|
|
if d.exp < 0 {
|
|
// TODO: optimize and reduce allocations
|
|
orgExp := d.exp
|
|
dOne := New(10^-int64(orgExp), orgExp)
|
|
d2 := d
|
|
d2.exp = 0
|
|
if d2.Mod(fiveDec).Equal(Zero) {
|
|
d2.exp = orgExp
|
|
d2 = d2.Sub(dOne)
|
|
d = d2
|
|
}
|
|
}
|
|
iVal = tenInt
|
|
case 25:
|
|
iVal = fourInt
|
|
case 50:
|
|
iVal = twoInt
|
|
case 100:
|
|
iVal = oneInt
|
|
default:
|
|
panic(fmt.Sprintf("Decimal does not support this Cash rounding interval `%d`. Supported: 5, 10, 15, 25, 50, 100", interval))
|
|
}
|
|
dVal := Decimal{
|
|
value: iVal,
|
|
}
|
|
// TODO: optimize those calculations to reduce the high allocations (~29 allocs).
|
|
return d.Mul(dVal).Round(0).Div(dVal).Truncate(2)
|
|
}
|
|
|
|
// Floor returns the nearest integer value less than or equal to d.
|
|
func (d Decimal) Floor() Decimal {
|
|
d.ensureInitialized()
|
|
|
|
if d.exp >= 0 {
|
|
return d
|
|
}
|
|
|
|
exp := big.NewInt(10)
|
|
|
|
// NOTE(vadim): must negate after casting to prevent int32 overflow
|
|
exp.Exp(exp, big.NewInt(-int64(d.exp)), nil)
|
|
|
|
z := new(big.Int).Div(d.value, exp)
|
|
return Decimal{value: z, exp: 0}
|
|
}
|
|
|
|
// Ceil returns the nearest integer value greater than or equal to d.
|
|
func (d Decimal) Ceil() Decimal {
|
|
d.ensureInitialized()
|
|
|
|
if d.exp >= 0 {
|
|
return d
|
|
}
|
|
|
|
exp := big.NewInt(10)
|
|
|
|
// NOTE(vadim): must negate after casting to prevent int32 overflow
|
|
exp.Exp(exp, big.NewInt(-int64(d.exp)), nil)
|
|
|
|
z, m := new(big.Int).DivMod(d.value, exp, new(big.Int))
|
|
if m.Cmp(zeroInt) != 0 {
|
|
z.Add(z, oneInt)
|
|
}
|
|
return Decimal{value: z, exp: 0}
|
|
}
|
|
|
|
// Truncate truncates off digits from the number, without rounding.
|
|
//
|
|
// NOTE: precision is the last digit that will not be truncated (must be >= 0).
|
|
//
|
|
// Example:
|
|
//
|
|
// decimal.NewFromString("123.456").Truncate(2).String() // "123.45"
|
|
//
|
|
func (d Decimal) Truncate(precision int32) Decimal {
|
|
d.ensureInitialized()
|
|
if precision >= 0 && -precision > d.exp {
|
|
return d.rescale(-precision)
|
|
}
|
|
return d
|
|
}
|
|
|
|
// UnmarshalJSON implements the json.Unmarshaler interface.
|
|
func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error {
|
|
if string(decimalBytes) == "null" {
|
|
return nil
|
|
}
|
|
|
|
str, err := unquoteIfQuoted(decimalBytes)
|
|
if err != nil {
|
|
return fmt.Errorf("Error decoding string '%s': %s", decimalBytes, err)
|
|
}
|
|
|
|
decimal, err := NewFromString(str)
|
|
*d = decimal
|
|
if err != nil {
|
|
return fmt.Errorf("Error decoding string '%s': %s", str, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// MarshalJSON implements the json.Marshaler interface.
|
|
func (d Decimal) MarshalJSON() ([]byte, error) {
|
|
var str string
|
|
if MarshalJSONWithoutQuotes {
|
|
str = d.String()
|
|
} else {
|
|
str = "\"" + d.String() + "\""
|
|
}
|
|
return []byte(str), nil
|
|
}
|
|
|
|
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. As a string representation
|
|
// is already used when encoding to text, this method stores that string as []byte
|
|
func (d *Decimal) UnmarshalBinary(data []byte) error {
|
|
// Extract the exponent
|
|
d.exp = int32(binary.BigEndian.Uint32(data[:4]))
|
|
|
|
// Extract the value
|
|
d.value = new(big.Int)
|
|
return d.value.GobDecode(data[4:])
|
|
}
|
|
|
|
// MarshalBinary implements the encoding.BinaryMarshaler interface.
|
|
func (d Decimal) MarshalBinary() (data []byte, err error) {
|
|
// Write the exponent first since it's a fixed size
|
|
v1 := make([]byte, 4)
|
|
binary.BigEndian.PutUint32(v1, uint32(d.exp))
|
|
|
|
// Add the value
|
|
var v2 []byte
|
|
if v2, err = d.value.GobEncode(); err != nil {
|
|
return
|
|
}
|
|
|
|
// Return the byte array
|
|
data = append(v1, v2...)
|
|
return
|
|
}
|
|
|
|
// Scan implements the sql.Scanner interface for database deserialization.
|
|
func (d *Decimal) Scan(value interface{}) error {
|
|
// first try to see if the data is stored in database as a Numeric datatype
|
|
switch v := value.(type) {
|
|
|
|
case float32:
|
|
*d = NewFromFloat(float64(v))
|
|
return nil
|
|
|
|
case float64:
|
|
// numeric in sqlite3 sends us float64
|
|
*d = NewFromFloat(v)
|
|
return nil
|
|
|
|
case int64:
|
|
// at least in sqlite3 when the value is 0 in db, the data is sent
|
|
// to us as an int64 instead of a float64 ...
|
|
*d = New(v, 0)
|
|
return nil
|
|
|
|
default:
|
|
// default is trying to interpret value stored as string
|
|
str, err := unquoteIfQuoted(v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*d, err = NewFromString(str)
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Value implements the driver.Valuer interface for database serialization.
|
|
func (d Decimal) Value() (driver.Value, error) {
|
|
return d.String(), nil
|
|
}
|
|
|
|
// UnmarshalText implements the encoding.TextUnmarshaler interface for XML
|
|
// deserialization.
|
|
func (d *Decimal) UnmarshalText(text []byte) error {
|
|
str := string(text)
|
|
|
|
dec, err := NewFromString(str)
|
|
*d = dec
|
|
if err != nil {
|
|
return fmt.Errorf("Error decoding string '%s': %s", str, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// MarshalText implements the encoding.TextMarshaler interface for XML
|
|
// serialization.
|
|
func (d Decimal) MarshalText() (text []byte, err error) {
|
|
return []byte(d.String()), nil
|
|
}
|
|
|
|
// GobEncode implements the gob.GobEncoder interface for gob serialization.
|
|
func (d Decimal) GobEncode() ([]byte, error) {
|
|
return d.MarshalBinary()
|
|
}
|
|
|
|
// GobDecode implements the gob.GobDecoder interface for gob serialization.
|
|
func (d *Decimal) GobDecode(data []byte) error {
|
|
return d.UnmarshalBinary(data)
|
|
}
|
|
|
|
// StringScaled first scales the decimal then calls .String() on it.
|
|
// NOTE: buggy, unintuitive, and DEPRECATED! Use StringFixed instead.
|
|
func (d Decimal) StringScaled(exp int32) string {
|
|
return d.rescale(exp).String()
|
|
}
|
|
|
|
func (d Decimal) string(trimTrailingZeros bool) string {
|
|
if d.exp >= 0 {
|
|
return d.rescale(0).value.String()
|
|
}
|
|
|
|
abs := new(big.Int).Abs(d.value)
|
|
str := abs.String()
|
|
|
|
var intPart, fractionalPart string
|
|
|
|
// NOTE(vadim): this cast to int will cause bugs if d.exp == INT_MIN
|
|
// and you are on a 32-bit machine. Won't fix this super-edge case.
|
|
dExpInt := int(d.exp)
|
|
if len(str) > -dExpInt {
|
|
intPart = str[:len(str)+dExpInt]
|
|
fractionalPart = str[len(str)+dExpInt:]
|
|
} else {
|
|
intPart = "0"
|
|
|
|
num0s := -dExpInt - len(str)
|
|
fractionalPart = strings.Repeat("0", num0s) + str
|
|
}
|
|
|
|
if trimTrailingZeros {
|
|
i := len(fractionalPart) - 1
|
|
for ; i >= 0; i-- {
|
|
if fractionalPart[i] != '0' {
|
|
break
|
|
}
|
|
}
|
|
fractionalPart = fractionalPart[:i+1]
|
|
}
|
|
|
|
number := intPart
|
|
if len(fractionalPart) > 0 {
|
|
number += "." + fractionalPart
|
|
}
|
|
|
|
if d.value.Sign() < 0 {
|
|
return "-" + number
|
|
}
|
|
|
|
return number
|
|
}
|
|
|
|
func (d *Decimal) ensureInitialized() {
|
|
if d.value == nil {
|
|
d.value = new(big.Int)
|
|
}
|
|
}
|
|
|
|
// Min returns the smallest Decimal that was passed in the arguments.
|
|
//
|
|
// To call this function with an array, you must do:
|
|
//
|
|
// Min(arr[0], arr[1:]...)
|
|
//
|
|
// This makes it harder to accidentally call Min with 0 arguments.
|
|
func Min(first Decimal, rest ...Decimal) Decimal {
|
|
ans := first
|
|
for _, item := range rest {
|
|
if item.Cmp(ans) < 0 {
|
|
ans = item
|
|
}
|
|
}
|
|
return ans
|
|
}
|
|
|
|
// Max returns the largest Decimal that was passed in the arguments.
|
|
//
|
|
// To call this function with an array, you must do:
|
|
//
|
|
// Max(arr[0], arr[1:]...)
|
|
//
|
|
// This makes it harder to accidentally call Max with 0 arguments.
|
|
func Max(first Decimal, rest ...Decimal) Decimal {
|
|
ans := first
|
|
for _, item := range rest {
|
|
if item.Cmp(ans) > 0 {
|
|
ans = item
|
|
}
|
|
}
|
|
return ans
|
|
}
|
|
|
|
// Sum returns the combined total of the provided first and rest Decimals
|
|
func Sum(first Decimal, rest ...Decimal) Decimal {
|
|
total := first
|
|
for _, item := range rest {
|
|
total = total.Add(item)
|
|
}
|
|
|
|
return total
|
|
}
|
|
|
|
// Avg returns the average value of the provided first and rest Decimals
|
|
func Avg(first Decimal, rest ...Decimal) Decimal {
|
|
count := New(int64(len(rest)+1), 0)
|
|
sum := Sum(first, rest...)
|
|
return sum.Div(count)
|
|
}
|
|
|
|
// min returns the smaller of two int32 values.
func min(x, y int32) int32 {
	if x < y {
		return x
	}
	return y
}
|
|
|
|
// unquoteIfQuoted converts a string or []byte driver/JSON value to a string,
// stripping one pair of surrounding double quotes when present. Other types
// are rejected with an error.
func unquoteIfQuoted(value interface{}) (string, error) {
	raw, ok := value.([]byte)
	if !ok {
		s, isString := value.(string)
		if !isString {
			return "", fmt.Errorf("Could not convert value '%+v' to byte array of type '%T'",
				value, value)
		}
		raw = []byte(s)
	}

	// If the amount is quoted, strip the quotes
	n := len(raw)
	if n > 2 && raw[0] == '"' && raw[n-1] == '"' {
		raw = raw[1 : n-1]
	}
	return string(raw), nil
}
|
|
|
|
// NullDecimal represents a nullable decimal with compatibility for
|
|
// scanning null values from the database.
|
|
type NullDecimal struct {
|
|
Decimal Decimal
|
|
Valid bool
|
|
}
|
|
|
|
// Scan implements the sql.Scanner interface for database deserialization.
|
|
func (d *NullDecimal) Scan(value interface{}) error {
|
|
if value == nil {
|
|
d.Valid = false
|
|
return nil
|
|
}
|
|
d.Valid = true
|
|
return d.Decimal.Scan(value)
|
|
}
|
|
|
|
// Value implements the driver.Valuer interface for database serialization.
|
|
func (d NullDecimal) Value() (driver.Value, error) {
|
|
if !d.Valid {
|
|
return nil, nil
|
|
}
|
|
return d.Decimal.Value()
|
|
}
|
|
|
|
// UnmarshalJSON implements the json.Unmarshaler interface.
|
|
func (d *NullDecimal) UnmarshalJSON(decimalBytes []byte) error {
|
|
if string(decimalBytes) == "null" {
|
|
d.Valid = false
|
|
return nil
|
|
}
|
|
d.Valid = true
|
|
return d.Decimal.UnmarshalJSON(decimalBytes)
|
|
}
|
|
|
|
// MarshalJSON implements the json.Marshaler interface.
|
|
func (d NullDecimal) MarshalJSON() ([]byte, error) {
|
|
if !d.Valid {
|
|
return []byte("null"), nil
|
|
}
|
|
return d.Decimal.MarshalJSON()
|
|
}
|