NullStringを綺麗にJSONエンコードするための稚拙なやり方

目的

タイトルの通り。
sql.NullStringはNullableなカラムを引き受けてくれる型だが、そのままJSONに引き渡してもオブジェクトとして出力されてしまう。
これを綺麗にJSONとして出力したかった。

綺麗でない例は以下のとおり

before

[{"id":1,"string":{"String":"aaa","Valid":true}},{"id":2,"string":{"String":"","Valid":false}}]

after

[{"id":1,"string":"aaa"},{"id":2,"string":""}]

方法

Null*型をJSONに引き渡して好きなかたちで出力したい場合、MarshalJSONを自分で実装すれば良い。問題はSQL実行後のScanでアサインされないことだった。

実装

取得元テーブル

サンプルなのでIDとNullableなカラムがあればよい

CREATE TABLE `test`.`testtable` (
 `id` INT NOT NULL AUTO_INCREMENT,
 `nullable` VARCHAR(45) NULL,
 PRIMARY KEY (`id`))

ソース

一旦ベタで貼る。短縮できる方法は試してみるが、考えるのが面倒だったのでconverterそのままコピペした。完全に良くない。

package main

import (
    "database/sql"
    "encoding/json"
    "os"
    "database/sql/driver"
    _ "github.com/go-sql-driver/mysql"
    log "github.com/Sirupsen/logrus"
    "fmt"
    "reflect"
    "strconv"
    "errors"
    "time"
)

type Test struct {
    Id int `json:"id"`
    Nullable LocalNullString `json:"string"`
}

func main() {

    db, err := sql.Open("mysql", "root:password@tcp(localhost:3306)/test")
    if err != nil {
        log.Fatalln(err)
    }

    defer db.Close()

    rows, err := db.Query("select id, nullable from test.testtable")
    if err != nil {
        log.Fatalln(err)
    }

    e := json.NewEncoder(os.Stdout)
    tests := []Test{}
    for rows.Next() {
        var t Test
        rows.Scan(&t.Id, &t.Nullable)

        tests = append(tests, t)
        e.Encode(t)
    }
    log.SetOutput(os.Stdout)
    log.Debug(json.Marshal(tests))
    e.Encode(tests)
}

type LocalNullString sql.NullString

func (l *LocalNullString) Scan(value interface{}) error {
    if value == nil {
        log.Info("appa==", value)
        l.String, l.Valid = "", false
        return nil
    } else {
        log.Info("bppa==", value)
        l.String = fmt.Sprint(value)
        l.Valid = true
        return convertAssign(&l.String, value)
    }
    return nil
}


func (l *LocalNullString) Value() (driver.Value, error) {
    if l.Valid {
        return l.Value()
    }
    return "", nil
}
func (l *LocalNullString) GetString() string {
    if l.Valid {
        return l.String
    }
    return ""
}

func (l *LocalNullString) MarshalJSON() ([]byte, error) {
    if l.Valid {
        return json.Marshal(l.String)
    }
    return json.Marshal("")
}

// ここからは完全にコピペなので見なくていい。

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Type conversions for Scan.

var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error

// convertAssign copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information.
// dest should be a pointer type.
func convertAssign(dest, src interface{}) error {
    // Common cases, without reflect.
    switch s := src.(type) {
    case string:
        switch d := dest.(type) {
        case *string:
            if d == nil {
                return errNilPtr
            }
            *d = s
            return nil
        case *[]byte:
            if d == nil {
                return errNilPtr
            }
            *d = []byte(s)
            return nil
        }
    case []byte:
        switch d := dest.(type) {
        case *string:
            if d == nil {
                return errNilPtr
            }
            *d = string(s)
            return nil
        case *interface{}:
            if d == nil {
                return errNilPtr
            }
            *d = cloneBytes(s)
            return nil
        case *[]byte:
            if d == nil {
                return errNilPtr
            }
            *d = cloneBytes(s)
            return nil
        case *sql.RawBytes:
            if d == nil {
                return errNilPtr
            }
            *d = s
            return nil
        }
    case time.Time:
        var t time.Time = s
        switch d := dest.(type) {
        case *string:
            *d = t.Format(time.RFC3339Nano)
            return nil
        case *[]byte:
            if d == nil {
                return errNilPtr
            }
            *d = []byte(t.Format(time.RFC3339Nano))
            return nil
        }
    case nil:
        switch d := dest.(type) {
        case *interface{}:
            if d == nil {
                return errNilPtr
            }
            *d = nil
            return nil
        case *[]byte:
            if d == nil {
                return errNilPtr
            }
            *d = nil
            return nil
        case *sql.RawBytes:
            if d == nil {
                return errNilPtr
            }
            *d = nil
            return nil
        }
    }

    var sv reflect.Value

    switch d := dest.(type) {
    case *string:
        sv = reflect.ValueOf(src)
        switch sv.Kind() {
        case reflect.Bool,
            reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
            reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
            reflect.Float32, reflect.Float64:
            *d = asString(src)
            return nil
        }
    case *[]byte:
        sv = reflect.ValueOf(src)
        if b, ok := asBytes(nil, sv); ok {
            *d = b
            return nil
        }
    case *sql.RawBytes:
        sv = reflect.ValueOf(src)
        if b, ok := asBytes([]byte(*d)[:0], sv); ok {
            *d = sql.RawBytes(b)
            return nil
        }
    case *bool:
        bv, err := driver.Bool.ConvertValue(src)
        if err == nil {
            *d = bv.(bool)
        }
        return err
    case *interface{}:
        *d = src
        return nil
    }

    if scanner, ok := dest.(sql.Scanner); ok {
        return scanner.Scan(src)
    }

    dpv := reflect.ValueOf(dest)
    if dpv.Kind() != reflect.Ptr {
        return errors.New("destination not a pointer")
    }
    if dpv.IsNil() {
        return errNilPtr
    }

    if !sv.IsValid() {
        sv = reflect.ValueOf(src)
    }

    dv := reflect.Indirect(dpv)
    if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
        dv.Set(sv)
        return nil
    }

    if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
        dv.Set(sv.Convert(dv.Type()))
        return nil
    }

    switch dv.Kind() {
    case reflect.Ptr:
        if src == nil {
            dv.Set(reflect.Zero(dv.Type()))
            return nil
        } else {
            dv.Set(reflect.New(dv.Type().Elem()))
            return convertAssign(dv.Interface(), src)
        }
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        s := asString(src)
        i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
        if err != nil {
            err = strconvErr(err)
            return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
        }
        dv.SetInt(i64)
        return nil
    case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        s := asString(src)
        u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
        if err != nil {
            err = strconvErr(err)
            return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
        }
        dv.SetUint(u64)
        return nil
    case reflect.Float32, reflect.Float64:
        s := asString(src)
        f64, err := strconv.ParseFloat(s, dv.Type().Bits())
        if err != nil {
            err = strconvErr(err)
            return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
        }
        dv.SetFloat(f64)
        return nil
    }

    return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
}

func strconvErr(err error) error {
    if ne, ok := err.(*strconv.NumError); ok {
        return ne.Err
    }
    return err
}

func cloneBytes(b []byte) []byte {
    if b == nil {
        return nil
    } else {
        c := make([]byte, len(b))
        copy(c, b)
        return c
    }
}

func asString(src interface{}) string {
    switch v := src.(type) {
    case string:
        return v
    case []byte:
        return string(v)
    }
    rv := reflect.ValueOf(src)
    switch rv.Kind() {
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        return strconv.FormatInt(rv.Int(), 10)
    case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        return strconv.FormatUint(rv.Uint(), 10)
    case reflect.Float64:
        return strconv.FormatFloat(rv.Float(), 'g', -1, 64)
    case reflect.Float32:
        return strconv.FormatFloat(rv.Float(), 'g', -1, 32)
    case reflect.Bool:
        return strconv.FormatBool(rv.Bool())
    }
    return fmt.Sprintf("%v", src)
}

func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
    switch rv.Kind() {
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        return strconv.AppendInt(buf, rv.Int(), 10), true
    case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        return strconv.AppendUint(buf, rv.Uint(), 10), true
    case reflect.Float32:
        return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true
    case reflect.Float64:
        return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true
    case reflect.Bool:
        return strconv.AppendBool(buf, rv.Bool()), true
    case reflect.String:
        s := rv.String()
        return append(buf, s...), true
    }
    return
}

goをもうちょっと勉強したいところであるが、仕事で触るのが覚えるための最速であるというのは大きく感じるところ。

このブログの人気の投稿

2016年にgoを使ったのでまとめ

採用とは何か

エンジニアの妻になってしまった我が妻には感謝している