initial commit

This commit is contained in:
Ben Vezzani
2024-09-19 22:20:30 -04:00
commit e699bde720
4 changed files with 239 additions and 0 deletions

12
decisionTree/db.go Normal file
View File

@@ -0,0 +1,12 @@
package decisionTree
type Vertex struct {
ID string
Value string
}
type Edge struct {
SourceID string
TargetID string
Value string
}

62
decisionTree/graph.go Normal file
View File

@@ -0,0 +1,62 @@
package decisionTree
type GraphPart struct {
Label string
Children map[string]GraphPart `json:"Children,omitempty"`
}
// Every question is a non-leaf vertex. Every answer is an edge. The root question is the one which is never a target
// of an answer. We'll start by finding it by using that definition.
var vertices = make(map[string]Vertex)
var edges []Edge
func GraphFromDBRows(vs []Vertex, es []Edge) GraphPart {
// As our graph is a simple tree (as opposed to polytree etc), there is always a single root node. It's not strictly
//necessary to find that node up front, but it does help make this example simpler.
rootQuestion := getRootVertex(vs, es)
/**
Since building the graph is a recursive operation, we can keep the stack a bit lighter (and our function signatures
simpler) by storing the questions and edges in package variables.
*/
edges = es
// Storing questions as a map for easy lookup by ID later
for _, v := range vs {
vertices[v.ID] = v
}
// Now that we have the root question, we can recursively build the rest of the tree
graph := GraphPart{Label: rootQuestion.Value, Children: make(map[string]GraphPart)}
buildChildren(&graph, rootQuestion)
// And we can return the graph, which provides references to the full graph
return graph
}
func buildChildren(n *GraphPart, v Vertex) {
for _, e := range edges {
if e.SourceID == v.ID {
child := GraphPart{Label: vertices[e.TargetID].Value, Children: make(map[string]GraphPart)}
buildChildren(&child, vertices[e.TargetID])
n.Children[e.Value] = child
}
}
}
func getRootVertex(vs []Vertex, es []Edge) Vertex {
targetIDs := map[string]struct{}{}
for _, a := range es {
targetIDs[a.TargetID] = struct{}{}
}
for _, v := range vs {
if _, present := targetIDs[v.ID]; !present {
return v
}
}
// A lazy panic. In the real world we should do nice error handling.
panic("couldn't find the root question")
}

164
decisionTree/graph_test.go Normal file
View File

@@ -0,0 +1,164 @@
package decisionTree
import (
"encoding/json"
"testing"
)
func TestGraphFromDBRows(t *testing.T) {
type args struct {
vs []Vertex
es []Edge
}
tests := []struct {
name string
args args
wantJson string
}{
{
name: "simple parent and child",
args: args{
vs: []Vertex{
Vertex{
ID: "root",
Value: "Are you human?",
},
Vertex{
ID: "leaf",
Value: "Congratulations",
},
},
es: []Edge{
Edge{
SourceID: "root",
TargetID: "leaf",
Value: "Yes",
},
},
},
wantJson: `{
"Label": "Are you human?",
"Children": {
"Yes": {
"Label": "Congratulations"
}
}
}`,
},
{
name: "Multiple decisions",
args: args{
vs: []Vertex{
Vertex{
ID: "root",
Value: "What do you prefer?",
},
Vertex{
ID: "cars",
Value: "What kind?",
},
Vertex{
ID: "boats",
Value: "What kind?",
},
Vertex{
ID: "trains",
Value: "https://www.youtube.com/watch?v=hHkKJfcBXcw",
},
Vertex{
ID: "sportscars",
Value: "You should consider Porsche",
},
Vertex{
ID: "luxurycars",
Value: "You should consider Mercedes",
},
Vertex{
ID: "sailboats",
Value: "You should consider Beneteau",
},
Vertex{
ID: "powerboats",
Value: "You should consider Formula",
},
},
es: []Edge{
Edge{
SourceID: "root",
TargetID: "cars",
Value: "Cars",
},
Edge{
SourceID: "root",
TargetID: "boats",
Value: "Boats",
},
Edge{
SourceID: "root",
TargetID: "trains",
Value: "Trains",
},
Edge{
SourceID: "cars",
TargetID: "sportscars",
Value: "Sporty cars",
},
Edge{
SourceID: "cars",
TargetID: "luxurycars",
Value: "Luxury cars",
},
Edge{
SourceID: "boats",
TargetID: "sailboats",
Value: "Sailboats",
},
Edge{
SourceID: "boats",
TargetID: "powerboats",
Value: "Power boats",
},
},
},
wantJson: `{
"Label": "What do you prefer?",
"Children": {
"Boats": {
"Label": "What kind?",
"Children": {
"Power boats": {
"Label": "You should consider Formula"
},
"Sailboats": {
"Label": "You should consider Beneteau"
}
}
},
"Cars": {
"Label": "What kind?",
"Children": {
"Luxury cars": {
"Label": "You should consider Mercedes"
},
"Sporty cars": {
"Label": "You should consider Porsche"
}
}
},
"Trains": {
"Label": "https://www.youtube.com/watch?v=hHkKJfcBXcw"
}
}
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GraphFromDBRows(tt.args.vs, tt.args.es)
gotJson, _ := json.MarshalIndent(got, "", "\t")
if string(gotJson) != tt.wantJson {
t.Errorf("Got `%s`, want `%s`", gotJson, tt.wantJson)
}
})
}
}

1
go.mod Normal file
View File

@@ -0,0 +1 @@
module random-stuff