Skip to content

Commit 4e0858a

Browse files
committed
Implement user defined functions
1 parent 9e4fc75 commit 4e0858a

File tree

4 files changed

+187
-28
lines changed

4 files changed

+187
-28
lines changed

lispy/builtin.go

+39-4
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,39 @@ func builtinOp(e *lenv, a *lval, op string) *lval {
3838
}
3939

4040
func builtinDef(e *lenv, a *lval) *lval {
41+
return builtinVar(e, a, "def")
42+
}
43+
44+
func builtinPut(e *lenv, a *lval) *lval {
45+
return builtinVar(e, a, "=")
46+
}
47+
48+
func builtinVar(e *lenv, a *lval, function string) *lval {
4149
if a.cells[0].ltype != lvalQexprType {
42-
return lvalErr("Function 'def' passed incorrect type: %s", a.cells[0].ltypeName())
50+
return lvalErr("Function %s passed incorrect type: %s", function, a.cells[0].ltypeName())
4351
}
4452
// First argument is symbol list
4553
syms := a.cells[0]
4654
// Ensure elements of first list are symbols
4755
for _, cell := range syms.cells {
4856
if cell.ltype != lvalSymType {
49-
return lvalErr("Function 'def' cannot define non-symbol: %s", cell.ltypeName())
57+
return lvalErr("Function %s cannot define non-symbol: %s", function, cell.ltypeName())
5058
}
5159
}
5260
// Check for the correct number of symbols and values
5361
if syms.cellCount() != a.cellCount()-1 {
54-
return lvalErr("Function 'def' cannot define incorrect number of values to symbols")
62+
return lvalErr("Function %s cannot define incorrect number of values to symbols", function)
5563
}
5664
// Assign copies of values to symbols
5765
for i, cell := range syms.cells {
58-
e.lenvPut(cell, a.cells[i+1])
66+
// 'def' to define globally
67+
if function == "def" {
68+
e.lenvDef(cell, a.cells[i+1])
69+
}
70+
// 'put' to define locally
71+
if function == "=" {
72+
e.lenvPut(cell, a.cells[i+1])
73+
}
5974
}
6075
return lvalSexpr()
6176
}
@@ -126,6 +141,26 @@ func builtinJoin(e *lenv, a *lval) *lval {
126141
return x
127142
}
128143

144+
func builtinLambda(e *lenv, a *lval) *lval {
145+
if a.cellCount() != 2 {
146+
return lvalErr("Lambda has %d arguments, not 2 as expected", a.cellCount())
147+
} else if a.cells[0].ltype != lvalQexprType {
148+
return lvalErr("Lambda cell[0] has unexpected type %d", a.cells[0].ltype)
149+
} else if a.cells[1].ltype != lvalQexprType {
150+
return lvalErr("Lambda cell[1] has unexpected type %d", a.cells[1].ltype)
151+
}
152+
// Check that the first Q-expression contains only Symbols
153+
for _, cell := range a.cells[0].cells {
154+
if cell.ltype != lvalSymType {
155+
return lvalErr("Cannot define non-symbol. Got type %s instead", cell.ltype)
156+
}
157+
}
158+
// Pop first 2 arguments and pass them to lvalLambda
159+
formals := a.lvalPop(0)
160+
body := a.lvalPop(0)
161+
return lvalLambda(formals, body)
162+
}
163+
129164
func builtinAdd(e *lenv, a *lval) *lval {
130165
return builtinOp(e, a, "+")
131166
}

lispy/lenv.go

+29-2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ package lispy
33
import "fmt"
44

55
type lenv struct {
6+
par *lenv
67
syms []string
78
vals []*lval
89
}
910

1011
func lenvNew() *lenv {
1112
e := new(lenv)
13+
e.par = nil
1214
e.syms = nil
1315
e.vals = nil
1416
return e
@@ -22,12 +24,26 @@ func (e *lenv) count() int {
2224
return len(e.syms)
2325
}
2426

27+
func lenvCopy(e *lenv) *lenv {
28+
n := new(lenv)
29+
n.par = e.par
30+
for i := 0; i < e.count(); i++ {
31+
n.syms = append(n.syms, string(e.syms[i]))
32+
n.vals = append(n.vals, lvalCopy(e.vals[i]))
33+
}
34+
return n
35+
}
36+
2537
func (e *lenv) lenvGet(k *lval) *lval {
26-
for i, sym := range e.syms {
27-
if sym == k.sym {
38+
for i := 0; i < e.count(); i++ {
39+
if e.syms[i] == k.sym {
2840
return lvalCopy(e.vals[i])
2941
}
3042
}
43+
// If no symbol is found yet, check in the parent
44+
if e.par != nil {
45+
return e.par.lenvGet(k)
46+
}
3147
return lvalErr("Unbound Symbol: '%s'", k.sym)
3248
}
3349

@@ -44,6 +60,15 @@ func (e *lenv) lenvPut(k, v *lval) {
4460
e.syms = append(e.syms, k.sym)
4561
}
4662

63+
func (e *lenv) lenvDef(k *lval, v *lval) {
64+
// Find top parent
65+
for e.par != nil {
66+
e = e.par
67+
}
68+
// Put value in e
69+
e.lenvPut(k, v)
70+
}
71+
4772
func (e *lenv) lenvAddBuiltin(name string, function lbuiltin) {
4873
k := lvalSym(name)
4974
v := lvalFun(function)
@@ -53,6 +78,8 @@ func (e *lenv) lenvAddBuiltin(name string, function lbuiltin) {
5378
func (e *lenv) lenvAddBuiltins() {
5479
// List Functions
5580
e.lenvAddBuiltin("def", builtinDef)
81+
e.lenvAddBuiltin("=", builtinPut)
82+
e.lenvAddBuiltin("\\", builtinLambda)
5683
e.lenvAddBuiltin("list", builtinList)
5784
e.lenvAddBuiltin("head", builtinHead)
5885
e.lenvAddBuiltin("tail", builtinTail)

lispy/lispy_test.go

+36-2
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ func TestStringOutput(t *testing.T) {
7777
{"eval (tail {tail tail {5 6 7}})", "{6 7}"},
7878
{"eval (head {(+ 1 2) (+ 10 20)})", "3"},
7979
{"eval (head {5 10 11 15})", "5"},
80-
{"+", "<function>"},
81-
{"eval (head {+ - = - * /})", "<function>"},
80+
{"+", "<builtin>"},
81+
{"eval (head {+ - = - * /})", "<builtin>"},
8282
{"(eval (head {+ - = - * /})) 10 20", "30"},
8383
{"hello", "Error: Unbound Symbol: 'hello'"},
8484
{"+ 1 {5 6 7}", "Error: Cannot operate on non-number: Q-Expression"},
@@ -148,3 +148,37 @@ func TestError(t *testing.T) {
148148
}
149149
}
150150
}
151+
152+
func TestFunctionDefinitions(t *testing.T) {
153+
l := InitLispy()
154+
defer CleanLispy(l)
155+
156+
cases := []struct {
157+
input string
158+
want string
159+
}{
160+
{"(\\ {x y} {+ x y})", "(\\ {x y} {+ x y})"},
161+
{"(\\ {x y} {+ x y}) 10 20", "30"},
162+
{"def {add-together} (\\ {x y} {+ x y})", "()"},
163+
{"add-together", "(\\ {x y} {+ x y})"},
164+
{"add-together 10 20", "30"},
165+
{"add-together", "(\\ {x y} {+ x y})"}, // Check for accidental modification
166+
{"def {add-mul} (\\ {x y} {+ x (* x y)})", "()"},
167+
{"add-mul", "(\\ {x y} {+ x (* x y)})"},
168+
{"add-mul 10 20", "210"},
169+
{"add-mul 10", "(\\ {y} {+ x (* x y)})"},
170+
{"def {add-mul-ten} (add-mul 10)", "()"},
171+
{"add-mul-ten", "(\\ {y} {+ x (* x y)})"},
172+
{"add-mul 10 50", "510"},
173+
{"add-mul-ten 50", "510"},
174+
{"add-mul", "(\\ {x y} {+ x (* x y)})"}, // Check for accidental modification
175+
{"add-mul-ten", "(\\ {y} {+ x (* x y)})"}, // Check for accidental modification
176+
}
177+
178+
for _, c := range cases {
179+
got := l.ReadEval(c.input, false)
180+
if got.lvalString() != c.want {
181+
t.Errorf("ReadEval input: \"%s\" returned: \"%s\", actually expected: \"%s\"", c.input, got.lvalString(), c.want)
182+
}
183+
}
184+
}

lispy/lval.go

+83-20
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,21 @@ const (
1919
)
2020

2121
type lval struct {
22-
ltype int
23-
num int64 // lvalNumType
24-
err string // lvalErrType
25-
sym string // lvalSymType
26-
function lbuiltin // lvalFunType
27-
cells []*lval // lvalSexprType, lvalQexprType
22+
ltype int
23+
24+
// Basic
25+
num int64 // lvalNumType
26+
err string // lvalErrType
27+
sym string // lvalSymType
28+
29+
// Function
30+
builtin lbuiltin // lvalFunType, nil for user defined function
31+
env *lenv
32+
formals *lval
33+
body *lval
34+
35+
// Expression
36+
cells []*lval // lvalSexprType, lvalQexprType
2837
}
2938

3039
// lvalNum creates an lval number
@@ -71,7 +80,21 @@ func lvalQexpr() *lval {
7180
func lvalFun(function lbuiltin) *lval {
7281
v := new(lval)
7382
v.ltype = lvalFunType
74-
v.function = function
83+
v.builtin = function
84+
return v
85+
}
86+
87+
// lvalLambda creates a user defined lval function
88+
func lvalLambda(formals *lval, body *lval) *lval {
89+
v := new(lval)
90+
v.ltype = lvalFunType
91+
// Builtin is nil for user defined functions
92+
v.builtin = nil
93+
// Init environment
94+
v.env = lenvNew()
95+
// Set formals and body
96+
v.formals = formals
97+
v.body = body
7598
return v
7699
}
77100

@@ -110,7 +133,10 @@ func (v *lval) lvalString() string {
110133
case lvalSymType:
111134
return (v.sym)
112135
case lvalFunType:
113-
return "<function>"
136+
if v.builtin == nil {
137+
return "(\\ " + v.formals.lvalString() + " " + v.body.lvalString() + ")"
138+
}
139+
return "<builtin>"
114140
case lvalSexprType:
115141
return v.lvalExprString("(", ")")
116142
case lvalQexprType:
@@ -133,7 +159,14 @@ func lvalCopy(v *lval) *lval {
133159
x.ltype = v.ltype
134160
switch v.ltype {
135161
case lvalFunType:
136-
x.function = v.function
162+
if v.builtin == nil {
163+
x.builtin = nil
164+
x.env = lenvCopy(v.env)
165+
x.formals = lvalCopy(v.formals)
166+
x.body = lvalCopy(v.body)
167+
} else {
168+
x.builtin = v.builtin
169+
}
137170
case lvalNumType:
138171
x.num = v.num
139172
case lvalErrType:
@@ -144,18 +177,15 @@ func lvalCopy(v *lval) *lval {
144177
fallthrough
145178
case lvalQexprType:
146179
for _, cell := range v.cells {
147-
x.cells = append(x.cells, cell)
180+
x.cells = append(x.cells, lvalCopy(cell))
148181
}
149182
}
150183
return x
151184
}
152185

153-
func (v *lval) lvalAdd(x *lval) {
154-
if x == nil {
155-
fmt.Println("ERROR: Failed to add lval, addition is nil")
156-
} else {
157-
v.cells = append(v.cells, x)
158-
}
186+
func lvalAdd(v *lval, x *lval) *lval {
187+
v.cells = append(v.cells, x)
188+
return v
159189
}
160190

161191
func (v *lval) lvalExprString(openChar string, closeChar string) string {
@@ -210,7 +240,7 @@ func lvalRead(tree mpc.MpcAst) *lval {
210240
mpc.GetTag(iChild) == "regex" {
211241
continue
212242
} else {
213-
x.lvalAdd(lvalRead(iChild))
243+
x = lvalAdd(x, lvalRead(iChild))
214244
}
215245
strconv.ParseInt(mpc.GetContents(tree), 10, 0)
216246
}
@@ -234,15 +264,15 @@ func (v *lval) lvalEvalSexpr(e *lenv) *lval {
234264
}
235265
// Single Expression
236266
if v.cellCount() == 1 {
237-
return v.lvalTake(0)
267+
return v.lvalTake(0).lvalEval(e)
238268
}
239269
// Ensure first element is a symbol
240270
f := v.lvalPop(0)
241271
if f.ltype != lvalFunType {
242272
return lvalErr("S-expression does not start with symbol! got: %s", f.ltypeName())
243273
}
244274
// Use first element as a function to get result
245-
return f.function(e, v)
275+
return lvalCall(e, f, v)
246276
}
247277

248278
func (v *lval) lvalEval(e *lenv) *lval {
@@ -269,7 +299,40 @@ func (v *lval) lvalTake(i int) *lval {
269299

270300
func lvalJoin(x *lval, y *lval) *lval {
271301
for y.cellCount() > 0 {
272-
x.lvalAdd(y.lvalPop(0))
302+
x = lvalAdd(x, y.lvalPop(0))
273303
}
274304
return x
275305
}
306+
307+
func lvalCall(e *lenv, f *lval, a *lval) *lval {
308+
// Simple Builtin case:
309+
if f.builtin != nil {
310+
return f.builtin(e, a)
311+
}
312+
// Record argument counts
313+
given := a.cellCount()
314+
total := f.formals.cellCount()
315+
// While arguments still remain to be processed
316+
for a.cellCount() > 0 {
317+
// If we've ran out of formal arguments to bind
318+
if f.formals.cellCount() == 0 {
319+
return lvalErr("Function passed too many arguments. Got %d, Expected %d", given, total)
320+
}
321+
// Pop the first symbol from the formal
322+
sym := f.formals.lvalPop(0)
323+
// Pop the next argument from the list
324+
val := a.lvalPop(0)
325+
// Bind a copy into the function's environment
326+
f.env.lenvPut(sym, val)
327+
}
328+
// If all formals have been bound, evaluate
329+
if f.formals.cellCount() == 0 {
330+
// Set environment parent to evaluation environment
331+
f.env.par = e
332+
// Evaluate and return
333+
ret := builtinEval(f.env, lvalAdd(lvalSexpr(), lvalCopy(f.body)))
334+
return ret
335+
}
336+
// Otherwise, return partially evaluated function
337+
return lvalCopy(f)
338+
}

0 commit comments

Comments
 (0)