initial commit
This commit is contained in:
12
decisionTree/db.go
Normal file
12
decisionTree/db.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package decisionTree
|
||||
|
||||
type Vertex struct {
|
||||
ID string
|
||||
Value string
|
||||
}
|
||||
|
||||
type Edge struct {
|
||||
SourceID string
|
||||
TargetID string
|
||||
Value string
|
||||
}
|
||||
62
decisionTree/graph.go
Normal file
62
decisionTree/graph.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package decisionTree
|
||||
|
||||
type GraphPart struct {
|
||||
Label string
|
||||
Children map[string]GraphPart `json:"Children,omitempty"`
|
||||
}
|
||||
|
||||
// Every question is a non-leaf vertex. Every answer is an edge. The root question is the one which is never a target
|
||||
// of an answer. We'll start by finding it by using that definition.
|
||||
var vertices = make(map[string]Vertex)
|
||||
var edges []Edge
|
||||
|
||||
func GraphFromDBRows(vs []Vertex, es []Edge) GraphPart {
|
||||
// As our graph is a simple tree (as opposed to polytree etc), there is always a single root node. It's not strictly
|
||||
//necessary to find that node up front, but it does help make this example simpler.
|
||||
rootQuestion := getRootVertex(vs, es)
|
||||
|
||||
/**
|
||||
Since building the graph is a recursive operation, we can keep the stack a bit lighter (and our function signatures
|
||||
simpler) by storing the questions and edges in package variables.
|
||||
*/
|
||||
edges = es
|
||||
|
||||
// Storing questions as a map for easy lookup by ID later
|
||||
for _, v := range vs {
|
||||
vertices[v.ID] = v
|
||||
}
|
||||
|
||||
// Now that we have the root question, we can recursively build the rest of the tree
|
||||
graph := GraphPart{Label: rootQuestion.Value, Children: make(map[string]GraphPart)}
|
||||
buildChildren(&graph, rootQuestion)
|
||||
|
||||
// And we can return the graph, which provides references to the full graph
|
||||
return graph
|
||||
}
|
||||
|
||||
func buildChildren(n *GraphPart, v Vertex) {
|
||||
for _, e := range edges {
|
||||
if e.SourceID == v.ID {
|
||||
child := GraphPart{Label: vertices[e.TargetID].Value, Children: make(map[string]GraphPart)}
|
||||
buildChildren(&child, vertices[e.TargetID])
|
||||
n.Children[e.Value] = child
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getRootVertex(vs []Vertex, es []Edge) Vertex {
|
||||
targetIDs := map[string]struct{}{}
|
||||
|
||||
for _, a := range es {
|
||||
targetIDs[a.TargetID] = struct{}{}
|
||||
}
|
||||
|
||||
for _, v := range vs {
|
||||
if _, present := targetIDs[v.ID]; !present {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// A lazy panic. In the real world we should do nice error handling.
|
||||
panic("couldn't find the root question")
|
||||
}
|
||||
164
decisionTree/graph_test.go
Normal file
164
decisionTree/graph_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package decisionTree
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGraphFromDBRows(t *testing.T) {
|
||||
type args struct {
|
||||
vs []Vertex
|
||||
es []Edge
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantJson string
|
||||
}{
|
||||
{
|
||||
name: "simple parent and child",
|
||||
args: args{
|
||||
vs: []Vertex{
|
||||
Vertex{
|
||||
ID: "root",
|
||||
Value: "Are you human?",
|
||||
},
|
||||
Vertex{
|
||||
ID: "leaf",
|
||||
Value: "Congratulations",
|
||||
},
|
||||
},
|
||||
es: []Edge{
|
||||
Edge{
|
||||
SourceID: "root",
|
||||
TargetID: "leaf",
|
||||
Value: "Yes",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantJson: `{
|
||||
"Label": "Are you human?",
|
||||
"Children": {
|
||||
"Yes": {
|
||||
"Label": "Congratulations"
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Multiple decisions",
|
||||
args: args{
|
||||
vs: []Vertex{
|
||||
Vertex{
|
||||
ID: "root",
|
||||
Value: "What do you prefer?",
|
||||
},
|
||||
Vertex{
|
||||
ID: "cars",
|
||||
Value: "What kind?",
|
||||
},
|
||||
Vertex{
|
||||
ID: "boats",
|
||||
Value: "What kind?",
|
||||
},
|
||||
Vertex{
|
||||
ID: "trains",
|
||||
Value: "https://www.youtube.com/watch?v=hHkKJfcBXcw",
|
||||
},
|
||||
Vertex{
|
||||
ID: "sportscars",
|
||||
Value: "You should consider Porsche",
|
||||
},
|
||||
Vertex{
|
||||
ID: "luxurycars",
|
||||
Value: "You should consider Mercedes",
|
||||
},
|
||||
Vertex{
|
||||
ID: "sailboats",
|
||||
Value: "You should consider Beneteau",
|
||||
},
|
||||
Vertex{
|
||||
ID: "powerboats",
|
||||
Value: "You should consider Formula",
|
||||
},
|
||||
},
|
||||
es: []Edge{
|
||||
Edge{
|
||||
SourceID: "root",
|
||||
TargetID: "cars",
|
||||
Value: "Cars",
|
||||
},
|
||||
Edge{
|
||||
SourceID: "root",
|
||||
TargetID: "boats",
|
||||
Value: "Boats",
|
||||
},
|
||||
Edge{
|
||||
SourceID: "root",
|
||||
TargetID: "trains",
|
||||
Value: "Trains",
|
||||
},
|
||||
Edge{
|
||||
SourceID: "cars",
|
||||
TargetID: "sportscars",
|
||||
Value: "Sporty cars",
|
||||
},
|
||||
Edge{
|
||||
SourceID: "cars",
|
||||
TargetID: "luxurycars",
|
||||
Value: "Luxury cars",
|
||||
},
|
||||
Edge{
|
||||
SourceID: "boats",
|
||||
TargetID: "sailboats",
|
||||
Value: "Sailboats",
|
||||
},
|
||||
Edge{
|
||||
SourceID: "boats",
|
||||
TargetID: "powerboats",
|
||||
Value: "Power boats",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantJson: `{
|
||||
"Label": "What do you prefer?",
|
||||
"Children": {
|
||||
"Boats": {
|
||||
"Label": "What kind?",
|
||||
"Children": {
|
||||
"Power boats": {
|
||||
"Label": "You should consider Formula"
|
||||
},
|
||||
"Sailboats": {
|
||||
"Label": "You should consider Beneteau"
|
||||
}
|
||||
}
|
||||
},
|
||||
"Cars": {
|
||||
"Label": "What kind?",
|
||||
"Children": {
|
||||
"Luxury cars": {
|
||||
"Label": "You should consider Mercedes"
|
||||
},
|
||||
"Sporty cars": {
|
||||
"Label": "You should consider Porsche"
|
||||
}
|
||||
}
|
||||
},
|
||||
"Trains": {
|
||||
"Label": "https://www.youtube.com/watch?v=hHkKJfcBXcw"
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GraphFromDBRows(tt.args.vs, tt.args.es)
|
||||
gotJson, _ := json.MarshalIndent(got, "", "\t")
|
||||
if string(gotJson) != tt.wantJson {
|
||||
t.Errorf("Got `%s`, want `%s`", gotJson, tt.wantJson)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user