package decisionTree

import (
	"encoding/json"
	"reflect"
	"testing"
)

// TestGraphFromDBRows checks that GraphFromDBRows assembles flat vertex and
// edge rows into a nested tree whose JSON encoding matches wantJSON.
//
// The comparison is structural rather than byte-for-byte: both the marshaled
// result and wantJSON are decoded into generic values and compared with
// reflect.DeepEqual. The expected literals can therefore be laid out freely
// (indentation, line breaks, key order) without affecting the outcome.
func TestGraphFromDBRows(t *testing.T) {
	type args struct {
		vs []Vertex
		es []Edge
	}
	tests := []struct {
		name     string
		args     args
		wantJSON string
	}{
		{
			name: "simple parent and child",
			args: args{
				vs: []Vertex{
					{ID: "root", Value: "Are you human?"},
					{ID: "leaf", Value: "Congratulations"},
				},
				es: []Edge{
					{SourceID: "root", TargetID: "leaf", Value: "Yes"},
				},
			},
			wantJSON: `{
	"Label": "Are you human?",
	"Children": {
		"Yes": {
			"Label": "Congratulations"
		}
	}
}`,
		},
		{
			name: "Multiple decisions",
			args: args{
				vs: []Vertex{
					{ID: "root", Value: "What do you prefer?"},
					{ID: "cars", Value: "What kind?"},
					{ID: "boats", Value: "What kind?"},
					{ID: "trains", Value: "https://www.youtube.com/watch?v=hHkKJfcBXcw"},
					{ID: "sportscars", Value: "You should consider Porsche"},
					{ID: "luxurycars", Value: "You should consider Mercedes"},
					{ID: "sailboats", Value: "You should consider Beneteau"},
					{ID: "powerboats", Value: "You should consider Formula"},
				},
				es: []Edge{
					{SourceID: "root", TargetID: "cars", Value: "Cars"},
					{SourceID: "root", TargetID: "boats", Value: "Boats"},
					{SourceID: "root", TargetID: "trains", Value: "Trains"},
					{SourceID: "cars", TargetID: "sportscars", Value: "Sporty cars"},
					{SourceID: "cars", TargetID: "luxurycars", Value: "Luxury cars"},
					{SourceID: "boats", TargetID: "sailboats", Value: "Sailboats"},
					{SourceID: "boats", TargetID: "powerboats", Value: "Power boats"},
				},
			},
			wantJSON: `{
	"Label": "What do you prefer?",
	"Children": {
		"Boats": {
			"Label": "What kind?",
			"Children": {
				"Power boats": {
					"Label": "You should consider Formula"
				},
				"Sailboats": {
					"Label": "You should consider Beneteau"
				}
			}
		},
		"Cars": {
			"Label": "What kind?",
			"Children": {
				"Luxury cars": {
					"Label": "You should consider Mercedes"
				},
				"Sporty cars": {
					"Label": "You should consider Porsche"
				}
			}
		},
		"Trains": {
			"Label": "https://www.youtube.com/watch?v=hHkKJfcBXcw"
		}
	}
}`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := GraphFromDBRows(tt.args.vs, tt.args.es)

			// Indented output is kept only so the failure message is readable.
			gotJSON, err := json.MarshalIndent(got, "", "\t")
			if err != nil {
				t.Fatalf("marshaling result: %v", err)
			}

			// Decode both sides into generic values so that whitespace and
			// formatting differences in the expected literal don't matter.
			var gotVal, wantVal interface{}
			if err := json.Unmarshal(gotJSON, &gotVal); err != nil {
				t.Fatalf("decoding marshaled result: %v", err)
			}
			if err := json.Unmarshal([]byte(tt.wantJSON), &wantVal); err != nil {
				t.Fatalf("decoding wantJSON (test fixture is invalid JSON): %v", err)
			}

			if !reflect.DeepEqual(gotVal, wantVal) {
				t.Errorf("Got `%s`, want `%s`", gotJSON, tt.wantJSON)
			}
		})
	}
}