commit e699bde72099fc94253f24f582bb30ccbc124c86 Author: Ben Vezzani Date: Thu Sep 19 22:20:30 2024 -0400 initial commit diff --git a/decisionTree/db.go b/decisionTree/db.go new file mode 100644 index 0000000..acb557b --- /dev/null +++ b/decisionTree/db.go @@ -0,0 +1,12 @@ +package decisionTree + +type Vertex struct { + ID string + Value string +} + +type Edge struct { + SourceID string + TargetID string + Value string +} diff --git a/decisionTree/graph.go b/decisionTree/graph.go new file mode 100644 index 0000000..57a16f5 --- /dev/null +++ b/decisionTree/graph.go @@ -0,0 +1,62 @@ +package decisionTree + +type GraphPart struct { + Label string + Children map[string]GraphPart `json:"Children,omitempty"` +} + +// Every question is a non-leaf vertex. Every answer is an edge. The root question is the one which is never a target +// of an answer. We'll start by finding it by using that definition. +var vertices = make(map[string]Vertex) +var edges []Edge + +func GraphFromDBRows(vs []Vertex, es []Edge) GraphPart { + // As our graph is a simple tree (as opposed to polytree etc), there is always a single root node. It's not strictly + //necessary to find that node up front, but it does help make this example simpler. + rootQuestion := getRootVertex(vs, es) + + /** + Since building the graph is a recursive operation, we can keep the stack a bit lighter (and our function signatures + simpler) by storing the questions and edges in package variables. + */ + edges = es + + // Storing questions as a map for easy lookup by ID later + for _, v := range vs { + vertices[v.ID] = v + } + + // Now that we have the root question, we can recursively build the rest of the tree + graph := GraphPart{Label: rootQuestion.Value, Children: make(map[string]GraphPart)} + buildChildren(&graph, rootQuestion) + + // And we can return the graph, which provides references to the full graph + return graph +} + +func buildChildren(n *GraphPart, v Vertex) { + for _, e := range edges { + if e.SourceID == v.ID { + child := GraphPart{Label: vertices[e.TargetID].Value, Children: make(map[string]GraphPart)} + buildChildren(&child, vertices[e.TargetID]) + n.Children[e.Value] = child + } + } +} + +func getRootVertex(vs []Vertex, es []Edge) Vertex { + targetIDs := map[string]struct{}{} + + for _, a := range es { + targetIDs[a.TargetID] = struct{}{} + } + + for _, v := range vs { + if _, present := targetIDs[v.ID]; !present { + return v + } + } + + // A lazy panic. In the real world we should do nice error handling. + panic("couldn't find the root question") +} diff --git a/decisionTree/graph_test.go b/decisionTree/graph_test.go new file mode 100644 index 0000000..b85e237 --- /dev/null +++ b/decisionTree/graph_test.go @@ -0,0 +1,164 @@ +package decisionTree + +import ( + "encoding/json" + "testing" +) + +func TestGraphFromDBRows(t *testing.T) { + type args struct { + vs []Vertex + es []Edge + } + tests := []struct { + name string + args args + wantJson string + }{ + { + name: "simple parent and child", + args: args{ + vs: []Vertex{ + Vertex{ + ID: "root", + Value: "Are you human?", + }, + Vertex{ + ID: "leaf", + Value: "Congratulations", + }, + }, + es: []Edge{ + Edge{ + SourceID: "root", + TargetID: "leaf", + Value: "Yes", + }, + }, + }, + wantJson: `{ + "Label": "Are you human?", + "Children": { + "Yes": { + "Label": "Congratulations" + } + } +}`, + }, + { + name: "Multiple decisions", + args: args{ + vs: []Vertex{ + Vertex{ + ID: "root", + Value: "What do you prefer?", + }, + Vertex{ + ID: "cars", + Value: "What kind?", + }, + Vertex{ + ID: "boats", + Value: "What kind?", + }, + Vertex{ + ID: "trains", + Value: "https://www.youtube.com/watch?v=hHkKJfcBXcw", + }, + Vertex{ + ID: "sportscars", + Value: "You should consider Porsche", + }, + Vertex{ + ID: "luxurycars", + Value: "You should consider Mercedes", + }, + Vertex{ + ID: "sailboats", + Value: "You should consider Beneteau", + }, + Vertex{ + ID: "powerboats", + Value: "You should consider Formula", + }, + }, + es: []Edge{ + Edge{ + SourceID: "root", + TargetID: "cars", + Value: "Cars", + }, + Edge{ + SourceID: "root", + TargetID: "boats", + Value: "Boats", + }, + Edge{ + SourceID: "root", + TargetID: "trains", + Value: "Trains", + }, + Edge{ + SourceID: "cars", + TargetID: "sportscars", + Value: "Sporty cars", + }, + Edge{ + SourceID: "cars", + TargetID: "luxurycars", + Value: "Luxury cars", + }, + Edge{ + SourceID: "boats", + TargetID: "sailboats", + Value: "Sailboats", + }, + Edge{ + SourceID: "boats", + TargetID: "powerboats", + Value: "Power boats", + }, + }, + }, + wantJson: `{ + "Label": "What do you prefer?", + "Children": { + "Boats": { + "Label": "What kind?", + "Children": { + "Power boats": { + "Label": "You should consider Formula" + }, + "Sailboats": { + "Label": "You should consider Beneteau" + } + } + }, + "Cars": { + "Label": "What kind?", + "Children": { + "Luxury cars": { + "Label": "You should consider Mercedes" + }, + "Sporty cars": { + "Label": "You should consider Porsche" + } + } + }, + "Trains": { + "Label": "https://www.youtube.com/watch?v=hHkKJfcBXcw" + } + } +}`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GraphFromDBRows(tt.args.vs, tt.args.es) + gotJson, _ := json.MarshalIndent(got, "", "\t") + if string(gotJson) != tt.wantJson { + t.Errorf("Got `%s`, want `%s`", gotJson, tt.wantJson) + } + }) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7028695 --- /dev/null +++ b/go.mod @@ -0,0 +1 @@ +module random-stuff