Golang Code Example
2018年11月12日
echo
Unix 里 echo 命令的一份实现,echo 把它的命令行参数打印成一行
echo1.go
// Echo1 prints its command-line arguments on a single line,
// separated by single spaces.
package main

import (
	"fmt"
	"os"
)

func main() {
	// Accumulate the arguments (skipping the program name at index 0),
	// putting a space before every argument except the first.
	var line string
	for idx := 1; idx < len(os.Args); idx++ {
		if idx > 1 {
			line += " "
		}
		line += os.Args[idx]
	}
	fmt.Println(line)
}
echo2.go
// Echo2 prints its command-line arguments on a single line,
// separated by single spaces, using a range loop over os.Args[1:].
package main

import (
	"fmt"
	"os"
)

func main() {
	var out, delim string
	for _, word := range os.Args[1:] {
		// The delimiter is empty for the first word and " " afterwards.
		out, delim = out+delim+word, " "
	}
	fmt.Println(out)
}
echo3.go
// Echo3 prints its command-line arguments on a single line,
// joined with single spaces by strings.Join.
package main

import (
	"fmt"
	"os"
	"strings"
)

func main() {
	args := os.Args[1:]
	line := strings.Join(args, " ")
	fmt.Println(line)
}
dup
dup 的第一个版本打印标准输入中多次出现的行,以重复次数开头
dup1.go
// Dup1 prints the text of each line that appears more than once in the
// standard input, preceded by its count.
package main

import (
	"bufio"
	"fmt"
	"os"
)

func main() {
	counts := make(map[string]int)
	input := bufio.NewScanner(os.Stdin)
	for input.Scan() {
		counts[input.Text()]++
	}
	// Scan returns false both at EOF and on a read error (including a
	// line longer than the scanner's buffer); report the latter instead
	// of silently printing counts for a partially read input.
	if err := input.Err(); err != nil {
		fmt.Fprintf(os.Stderr, "dup1: %v\n", err)
		os.Exit(1)
	}
	// Map iteration order is unspecified, so the output order may vary.
	for line, n := range counts {
		if n > 1 {
			fmt.Printf("%d\t%s\n", n, line)
		}
	}
}
读取标准输入或是使用 os.Open 打开各个具名文件,并操作它们
dup2.go
// Dup2 prints the count and text of lines that appear more than once
// in the input. It reads from stdin or from a list of named files.
package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

func main() {
	counts := make(map[string]int)
	files := os.Args[1:]
	if len(files) == 0 {
		if err := countLines(os.Stdin, counts); err != nil {
			fmt.Fprintf(os.Stderr, "dup2: %v\n", err)
		}
	} else {
		for _, arg := range files {
			f, err := os.Open(arg)
			if err != nil {
				fmt.Fprintf(os.Stderr, "dup2: %v\n", err)
				continue
			}
			err = countLines(f, counts)
			f.Close()
			if err != nil {
				fmt.Fprintf(os.Stderr, "dup2: reading %s: %v\n", arg, err)
			}
		}
	}

	for line, n := range counts {
		if n > 1 {
			fmt.Printf("%d\t%s\n", n, line)
		}
	}
}

// countLines increments counts[line] for every line read from r and
// returns the scanner's error (nil at a clean EOF). Lines counted before
// an error stay in counts. Accepting io.Reader (rather than *os.File)
// lets any byte source be counted.
func countLines(r io.Reader, counts map[string]int) error {
	input := bufio.NewScanner(r)
	for input.Scan() {
		counts[input.Text()]++
	}
	return input.Err()
}
dup3.go
// Dup3 prints the count and text of lines that appear more than once
// in the named input files. Each file is read into memory in one go.
package main

import (
	"fmt"
	"io/ioutil"
	"os"
	"strings"
)

func main() {
	counts := make(map[string]int)
	for _, filename := range os.Args[1:] {
		data, err := ioutil.ReadFile(filename)
		if err != nil {
			fmt.Fprintf(os.Stderr, "dup3: %v\n", err)
			continue
		}
		for _, line := range splitLines(string(data)) {
			counts[line]++
		}
	}
	for line, n := range counts {
		if n > 1 {
			fmt.Printf("%d\t%s\n", n, line)
		}
	}
}

// splitLines splits text into its "\n"-terminated lines. A bare
// strings.Split would yield an extra "" after a trailing newline (and a
// single "" for empty text), so every newline-terminated file would
// contribute a phantom empty line and dup3 would report "" as a
// duplicate line.
func splitLines(text string) []string {
	if text == "" {
		return nil
	}
	return strings.Split(strings.TrimSuffix(text, "\n"), "\n")
}
fetch
fetch.go
// Fetch prints the content found at each URL named on the command line.
// It stops at the first URL that cannot be fetched or read.
package main

import (
	"fmt"
	"io/ioutil"
	"net/http"
	"os"
)

func main() {
	for _, url := range os.Args[1:] {
		fmt.Printf("%s", fetchBody(url))
	}
}

// fetchBody GETs url and returns its whole response body. Any failure is
// reported on stderr and terminates the program with exit status 1.
func fetchBody(url string) []byte {
	resp, err := http.Get(url)
	if err != nil {
		fmt.Fprintf(os.Stderr, "fetch: %v\n", err)
		os.Exit(1)
	}
	body, err := ioutil.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		fmt.Fprintf(os.Stderr, "fetch: reading %s: %v\n", url, err)
		os.Exit(1)
	}
	return body
}
fetchall.go
并发获取多个 URL
// Fetchall fetches URLs in parallel and reports their times and sizes.
package main

import (
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"os"
	"time"
)

func main() {
	start := time.Now()
	urls := os.Args[1:]
	results := make(chan string)
	for _, u := range urls {
		go fetch(u, results)
	}
	// Each fetch sends exactly one line, so receive once per URL.
	for i := 0; i < len(urls); i++ {
		fmt.Println(<-results)
	}
	fmt.Printf("%.2fs elapsed\n", time.Since(start).Seconds())
}

// fetch GETs url, discards the body while counting its bytes, and sends
// one summary (or error description) line on ch.
func fetch(url string, ch chan<- string) {
	begin := time.Now()
	resp, err := http.Get(url)
	if err != nil {
		ch <- fmt.Sprint(err)
		return
	}
	nbytes, copyErr := io.Copy(ioutil.Discard, resp.Body)
	resp.Body.Close()
	if copyErr != nil {
		ch <- fmt.Sprintf("while reading %s: %v", url, copyErr)
		return
	}
	ch <- fmt.Sprintf("%.2fs %7d %s", time.Since(begin).Seconds(), nbytes, url)
}
server
server1.go
// Server1 is a minimal "echo" server that replies with the request's URL path.
package main

import (
	"fmt"
	"log"
	"net/http"
)

func main() {
	http.HandleFunc("/", handler)
	err := http.ListenAndServe("localhost:8000", nil)
	log.Fatal(err)
}

// handler writes the Path component of the request URL, quoted.
func handler(w http.ResponseWriter, r *http.Request) {
	path := r.URL.Path
	fmt.Fprintf(w, "URL.Path = %q\n", path)
}
server2.go
// Server2 is a minimal "echo" and counter server.
package main

import (
	"fmt"
	"log"
	"net/http"
	"sync"
)

var (
	mu    sync.Mutex // guards count
	count int        // number of requests served by handler
)

func main() {
	http.HandleFunc("/", handler)
	http.HandleFunc("/count", counter)
	log.Fatal(http.ListenAndServe("localhost:8000", nil))
}

// handler echoes the Path component of the requested URL and increments count.
func handler(w http.ResponseWriter, r *http.Request) {
	mu.Lock()
	count++
	mu.Unlock()
	fmt.Fprintf(w, "URL.Path = %q\n", r.URL.Path)
}

// counter echoes the number of calls to handler so far. The count is
// snapshotted under the lock and written after releasing it, so a slow
// client cannot hold mu and stall every other request.
func counter(w http.ResponseWriter, r *http.Request) {
	mu.Lock()
	n := count
	mu.Unlock()
	fmt.Fprintf(w, "Count %d\n", n)
}
server3.go
handler 函数会把请求的 http 头和请求的 form 数据都打印出来,这样可以使检查和调试这个服务更为方便
// Server3 echoes the request line, headers, host, remote address and
// form values back to the client, which makes debugging easier.
package main

import (
	"fmt"
	"log"
	"net/http"
)

func main() {
	http.HandleFunc("/", handler)
	log.Fatal(http.ListenAndServe("localhost:8000", nil))
}

// handler writes a textual dump of the incoming request.
func handler(w http.ResponseWriter, r *http.Request) {
	// Request line: method, URL and protocol version.
	fmt.Fprintf(w, "%s %s %s\n", r.Method, r.URL, r.Proto)
	for name, values := range r.Header {
		fmt.Fprintf(w, "Header[%q] = %q\n", name, values)
	}
	fmt.Fprintf(w, "Host = %q\n", r.Host)
	fmt.Fprintf(w, "RemoteAddr = %q\n", r.RemoteAddr)
	// A ParseForm failure is only logged server-side; the loop below
	// still runs over whatever r.Form holds.
	err := r.ParseForm()
	if err != nil {
		log.Print(err)
	}
	for name, values := range r.Form {
		fmt.Fprintf(w, "Form[%q] = %q\n", name, values)
	}
}
basename
basename1.go
// Basename1 demonstrates a hand-written basename that strips the
// directory prefix and the final file extension.
package main

import "fmt"

// basename removes directory components and a trailing .suffix.
// e.g., a => a, a.go => a, a/b/c.go => c, a/b.c.go => b.c
func basename(s string) string {
	// Drop everything up to and including the last '/'; if there is no
	// '/', i ends at -1 and s is left unchanged.
	i := len(s) - 1
	for i >= 0 && s[i] != '/' {
		i--
	}
	s = s[i+1:]
	// Keep everything before the last '.', if there is one.
	for j := len(s) - 1; j >= 0; j-- {
		if s[j] == '.' {
			return s[:j]
		}
	}
	return s
}

func main() {
	fmt.Println(basename("a/b/c.go"))
	fmt.Println(basename("c.d.go"))
	fmt.Println(basename("abc"))
}
basename2.go
// Basename2 implements basename with strings.LastIndex.
package main

import (
	"fmt"
	"strings"
)

// basename removes directory components and a trailing .suffix.
func basename(s string) string {
	// LastIndex returns -1 when there is no '/', making the slice all of s.
	s = s[strings.LastIndex(s, "/")+1:]
	dot := strings.LastIndex(s, ".")
	if dot < 0 {
		return s
	}
	return s[:dot]
}

func main() {
	fmt.Println(basename("a/b/c.go"))
	fmt.Println(basename("c.d.go"))
	fmt.Println(basename("abc"))
}
comma
comma.go
// Comma inserts thousands separators into decimal number strings.
package main

import (
	"fmt"
	"strings"
)

// comma inserts commas every three digits into the integer part of the
// decimal string s, e.g. "12345" => "12,345". An optional leading '+' or
// '-' sign and a fractional part starting at the first '.' are passed
// through untouched, so "-1234.5678" => "-1,234.5678". Previously the
// sign was counted as a digit ("-123" => "-,123").
func comma(s string) string {
	sign := ""
	if len(s) > 0 && (s[0] == '+' || s[0] == '-') {
		sign, s = s[:1], s[1:]
	}
	frac := ""
	if dot := strings.IndexByte(s, '.'); dot >= 0 {
		s, frac = s[:dot], s[dot:]
	}
	return sign + groupDigits(s) + frac
}

// groupDigits inserts a comma before every group of three characters,
// counting from the right of the unsigned integer string s.
func groupDigits(s string) string {
	n := len(s)
	if n <= 3 {
		return s
	}
	return groupDigits(s[:n-3]) + "," + s[n-3:]
}

func main() {
	fmt.Println(comma("12345"))
	fmt.Println(comma("12345678"))
}
reverse
rev.go
// Rev reverses a slice of ints in place.
package main

import "fmt"

// reverse reverses the elements of s in place by swapping each element
// of the first half with its mirror in the second half.
func reverse(s []int) {
	n := len(s)
	for i := 0; i < n/2; i++ {
		s[i], s[n-1-i] = s[n-1-i], s[i]
	}
}

func main() {
	a := [...]int{0, 1, 2, 3, 4, 5}
	reverse(a[:])
	fmt.Println(a)
}
dedup
dedup.go
程序读取多行输入,但是只打印第一次出现的行
// Dedup reads lines from standard input and prints each distinct line
// only the first time it is seen.
package main

import (
	"bufio"
	"fmt"
	"os"
)

func main() {
	// seen is used as a set of the lines printed so far.
	seen := make(map[string]bool)
	sc := bufio.NewScanner(os.Stdin)
	for sc.Scan() {
		line := sc.Text()
		if seen[line] {
			continue
		}
		seen[line] = true
		fmt.Println(line)
	}

	if err := sc.Err(); err != nil {
		fmt.Fprintf(os.Stderr, "dedup: %v\n", err)
		os.Exit(1)
	}
}
treesort
treesort.go
使用二叉树实现插入排序
// Treesort sorts a slice of ints by inserting every value into an
// unbalanced binary search tree and reading it back in order.
package main

import "fmt"

func main() {
	values := []int{5, 4, 3, 9, 21, 4}
	Sort(values)
	fmt.Println(values)
}

// tree is a binary search tree node. Smaller values go to the left
// subtree; equal or larger values go to the right.
type tree struct {
	value       int
	left, right *tree
}

// Sort sorts values in place.
func Sort(values []int) {
	var root *tree
	for _, v := range values {
		root = add(root, v)
	}
	// Reuse values' backing array as the destination: the tree holds
	// exactly len(values) nodes, so the appends never reallocate.
	appendValues(values[:0], root)
}

// appendValues appends the values of t to values in order (in-order
// traversal) and returns the resulting slice.
func appendValues(values []int, t *tree) []int {
	if t == nil {
		return values
	}
	values = appendValues(values, t.left)
	values = append(values, t.value)
	return appendValues(values, t.right)
}

// add inserts value into the tree rooted at t and returns the (possibly
// new) root.
func add(t *tree, value int) *tree {
	node := &tree{value: value}
	if t == nil {
		return node
	}
	cur := t
	for {
		if value < cur.value {
			if cur.left == nil {
				cur.left = node
				return t
			}
			cur = cur.left
		} else {
			if cur.right == nil {
				cur.right = node
				return t
			}
			cur = cur.right
		}
	}
}
JSON
movie.go
// Movie demonstrates encoding/json: marshaling a slice of structs
// (compact and indented) and unmarshaling it back, both into a partial
// anonymous struct and into the full Movie type.
package main

import (
	"encoding/json"
	"fmt"
	"log"
)

func main() {
	data, err := json.Marshal(movies)
	if err != nil {
		log.Fatalf("JSON marshaling failed: %s", err)
	}
	fmt.Printf("%s\n", data)

	data, err = json.MarshalIndent(movies, "", " ")
	if err != nil {
		log.Fatalf("JSON marshaling failed: %s", err)
	}
	fmt.Printf("%s\n", data)

	// Decode only the titles: JSON keys with no matching field are
	// ignored, and "title" matches field Title case-insensitively.
	var titles []struct{ Title string }
	if err := json.Unmarshal(data, &titles); err != nil {
		log.Fatalf("JSON unmarshaling failed: %s", err)
	}
	fmt.Println(titles)

	var items []Movie
	if err := json.Unmarshal(data, &items); err != nil {
		log.Fatalf("JSON unmarshaling failed: %s", err)
	}
	fmt.Println(items)
}

// Movie is one film record. The struct tags rename fields in the JSON
// encoding, and omitempty leaves Color out when it is false.
type Movie struct {
	Title  string `json:"title"`
	Year   int    `json:"released"`
	Color  bool   `json:"color,omitempty"`
	Actors []string
}

// movies is the sample data set. Each entry describes a distinct film.
var movies = []Movie{
	{
		Title:  "Casablanca",
		Year:   1942,
		Color:  false,
		Actors: []string{"Humphrey Bogart", "Ingrid Bergman"},
	},
	{
		Title:  "Cool Hand Luke",
		Year:   1967,
		Color:  true,
		Actors: []string{"Paul Newman"},
	},
	{
		Title:  "Bullitt",
		Year:   1968,
		Color:  true,
		Actors: []string{"Steve McQueen", "Jacqueline Bisset"},
	},
}
模板
github.go
// Github prints a table, followed by a templated report, of the GitHub
// issues matching the search terms given on the command line.
package main

import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"os"
	"strings"
	"text/template"
	"time"
)

func main() {
	result, err := SearchIssues(os.Args[1:])
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%d issues:\n", result.TotalCount)
	for _, item := range result.Items {
		fmt.Printf("#%-5d %9.9s %.55s\n", item.Number, item.User.Login, item.Title)
	}

	fmt.Println()
	if err := report.Execute(os.Stdout, result); err != nil {
		log.Fatal(err)
	}
}

const templ = `{{.TotalCount}} issues:
{{range .Items}}----------------------------------------
Number: {{.Number}}
User: {{.User.Login}}
Title: {{.Title | printf "%.64s"}}
Age: {{.CreatedAt | daysAgo}} days
{{end}}`

// report renders an IssuesSearchResult using templ; template.Must panics
// at startup if templ fails to parse.
var report = template.Must(template.New("issuelist").
	Funcs(template.FuncMap{"daysAgo": daysAgo}).
	Parse(templ))

// daysAgo reports the number of whole days elapsed since t.
func daysAgo(t time.Time) int {
	return int(time.Since(t).Hours() / 24)
}

// IssuesURL is the GitHub issue-search API endpoint.
const IssuesURL = "https://api.github.com/search/issues"

// IssuesSearchResult is the decoded response of an issue search.
type IssuesSearchResult struct {
	TotalCount int `json:"total_count"`
	Items      []*Issue
}

// Issue is one search hit.
type Issue struct {
	Number    int
	HTMLURL   string `json:"html_url"`
	Title     string
	State     string
	User      *User
	CreatedAt time.Time `json:"created_at"`
	Body      string
}

// User identifies the author of an issue.
type User struct {
	Login   string
	HTMLURL string `json:"html_url"`
}

// httpClient is shared by all searches. The timeout stops a stalled
// connection from hanging the program forever, which a bare http.Get
// (no timeout) allows.
var httpClient = &http.Client{Timeout: 30 * time.Second}

// SearchIssues queries the GitHub issue tracker with the space-joined terms.
// It returns an error for transport failures, any non-200 status, or a
// response body that is not valid JSON for IssuesSearchResult.
func SearchIssues(terms []string) (*IssuesSearchResult, error) {
	q := url.QueryEscape(strings.Join(terms, " "))
	resp, err := httpClient.Get(IssuesURL + "?q=" + q)
	if err != nil {
		return nil, err
	}
	// A single deferred Close covers every return path below.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("search query failed: %s", resp.Status)
	}

	var result IssuesSearchResult
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, err
	}
	return &result, nil
}
递归
findlinks1.go
1package main 2 3import ( 4 "fmt" 5 "os" 6 7 "golang.org/x/net/html" 8) 9 10func main() { 11 doc, err := html.Parse(os.Stdin) 12 if err != nil { 13 fmt.Fprintf(os.Stderr, "findlinks1: %v\n", err) 14 os.Exit(1) 15 } 16 for _, link := range visit(nil, doc) { 17 fmt.Println(link) 18 } 19 20 fmt.Println() 21 outline(nil, doc) 22} 23 24func visit(links []string, n *html.Node) []string { 25 if n.Type == html.ElementNode && n.Data == "a" { 26 for _, a := range n.Attr { 27 if a.Key == "href" { 28 links = append(links, a.Val) 29 } 30 } 31 } 32 for c := n.FirstChild; c != nil; c = c.NextSibling { 33 links = visit(links, c) 34 } 35 return links 36} 37 38func outline(stack []string, n *html.Node) { 39 if n.Type == html.ElementNode { 40 stack = append(stack, n.Data) 41 fmt.Println(stack) 42 } 43 for c := n.FirstChild; c != nil; c = c.NextSibling { 44 outline(stack, c) 45 } 46}
findlinks2.go
1package main 2 3import ( 4 "fmt" 5 "net/http" 6 "os" 7 8 "golang.org/x/net/html" 9) 10 11func main() { 12 for _, url := range os.Args[1:] { 13 links, err := findLinks(url) 14 if err != nil { 15 fmt.Fprintf(os.Stderr, "findlinks2: %v\n", err) 16 continue 17 } 18 for _, link := range links { 19 fmt.Println(link) 20 } 21 } 22} 23 24func visit(links []string, n *html.Node) []string { 25 if n.Type == html.ElementNode && n.Data == "a" { 26 for _, a := range n.Attr { 27 if a.Key == "href" { 28 links = append(links, a.Val) 29 } 30 } 31 } 32 for c := n.FirstChild; c != nil; c = c.NextSibling { 33 links = visit(links, c) 34 } 35 return links 36} 37 38func findLinks(url string) ([]string, error) { 39 resp, err := http.Get(url) 40 if err != nil { 41 return nil, err 42 } 43 if resp.StatusCode != http.StatusOK { 44 resp.Body.Close() 45 return nil, fmt.Errorf("getting %s: %s", url, resp.Status) 46 } 47 doc, err := html.Parse(resp.Body) 48 resp.Body.Close() 49 if err != nil { 50 return nil, fmt.Errorf("parsing %s as HTML: %v", url, err) 51 } 52 return visit(nil, doc), nil 53}
匿名函数
topsort.go
用深度优先搜索遍历整张先修关系图,得到一个满足先修要求的课程顺序(即拓扑排序)
// Topsort prints an order in which the courses can be taken so that
// every course comes after its prerequisites.
package main

import (
	"fmt"
	"sort"
)

// prereqs maps computer science courses to their prerequisites.
var prereqs = map[string][]string{
	"algorithms": {"data structures"},
	"calculus":   {"linear algebra"},

	"compilers": {
		"data structures",
		"formal languages",
		"computer organization",
	},

	"data structures":       {"discrete math"},
	"databases":             {"data structures"},
	"discrete math":         {"intro to programming"},
	"formal languages":      {"discrete math"},
	"networks":              {"operating systems"},
	"operating systems":     {"data structures", "computer organization"},
	"programming languages": {"data structures", "computer organization"},
}

func main() {
	for i, course := range topoSort(prereqs) {
		fmt.Printf("%d:\t%s\n", i+1, course)
	}
}

// topoSort returns every key of m, plus every item those keys
// (transitively) list in m, in depth-first post-order. For an acyclic m
// this puts each item after all the items m lists for it. Top-level keys
// are visited in sorted order, so the result is deterministic. Cycles
// are not detected.
func topoSort(m map[string][]string) []string {
	var order []string
	seen := make(map[string]bool)
	var visitAll func(items []string)
	visitAll = func(items []string) {
		for _, item := range items {
			if !seen[item] {
				seen[item] = true
				visitAll(m[item])
				order = append(order, item)
			}
		}
	}
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	visitAll(keys)
	return order
}
findlinks3.go
广度优先
1package main 2 3import ( 4 "fmt" 5 "log" 6 "net/http" 7 "os" 8 9 "golang.org/x/net/html" 10) 11 12func main() { 13 breadthFirst(crawl, os.Args[1:]) 14} 15 16func breadthFirst(f func(item string) []string, worklist []string) { 17 seen := make(map[string]bool) 18 for len(worklist) > 0 { 19 items := worklist 20 worklist = nil 21 for _, item := range items { 22 if !seen[item] { 23 seen[item] = true 24 worklist = append(worklist, f(item)...) 25 } 26 } 27 } 28} 29 30func crawl(url string) []string { 31 fmt.Println(url) 32 list, err := Extract(url) 33 if err != nil { 34 log.Print(err) 35 } 36 return list 37} 38 39func Extract(url string) ([]string, error) { 40 resp, err := http.Get(url) 41 if err != nil { 42 return nil, err 43 } 44 if resp.StatusCode != http.StatusOK { 45 resp.Body.Close() 46 return nil, fmt.Errorf("getting %s: %s", url, resp.Status) 47 } 48 doc, err := html.Parse(resp.Body) 49 resp.Body.Close() 50 if err != nil { 51 return nil, fmt.Errorf("parsing %s as HTML: %v", url, err) 52 } 53 var links []string 54 visitNode := func(n *html.Node) { 55 if n.Type == html.ElementNode && n.Data == "a" { 56 for _, a := range n.Attr { 57 if a.Key != "href" { 58 continue 59 } 60 link, err := resp.Request.URL.Parse(a.Val) 61 if err != nil { 62 continue 63 } 64 links = append(links, link.String()) 65 } 66 } 67 } 68 forEachNode(doc, visitNode, nil) 69 return links, nil 70} 71 72func forEachNode(n *html.Node, pre, post func(n *html.Node)) { 73 if pre != nil { 74 pre(n) 75 } 76 for c := n.FirstChild; c != nil; c = c.NextSibling { 77 forEachNode(c, pre, post) 78 } 79 if post != nil { 80 post(n) 81 } 82}
defer
两种写法对比
title1.go
1package main 2 3import ( 4 "fmt" 5 "log" 6 "net/http" 7 "os" 8 "strings" 9 10 "golang.org/x/net/html" 11) 12 13func main() { 14 for _, val := range os.Args[1:] { 15 if err := title(val); err != nil { 16 log.Print(err) 17 } 18 } 19} 20 21func title(url string) error { 22 resp, err := http.Get(url) 23 if err != nil { 24 return err 25 } 26 27 ct := resp.Header.Get("Content-Type") 28 if ct != "text/html" && !strings.HasPrefix(ct, "text/html;") { 29 resp.Body.Close() 30 return fmt.Errorf("%s has type %s, not text/html", url, ct) 31 } 32 33 doc, err := html.Parse(resp.Body) 34 resp.Body.Close() 35 if err != nil { 36 return fmt.Errorf("parsing %s as HTML: %v", url, err) 37 } 38 visitNode := func(n *html.Node) { 39 if n.Type == html.ElementNode && n.Data == "title" && n.FirstChild != nil { 40 fmt.Println(n.FirstChild.Data) 41 } 42 } 43 forEachNode(doc, visitNode, nil) 44 return nil 45} 46 47func forEachNode(n *html.Node, pre, post func(n *html.Node)) { 48 if pre != nil { 49 pre(n) 50 } 51 for c := n.FirstChild; c != nil; c = c.NextSibling { 52 forEachNode(c, pre, post) 53 } 54 if post != nil { 55 post(n) 56 } 57}
title2.go
1package main 2 3import ( 4 "fmt" 5 "log" 6 "net/http" 7 "os" 8 "strings" 9 10 "golang.org/x/net/html" 11) 12 13func main() { 14 for _, val := range os.Args[1:] { 15 if err := title(val); err != nil { 16 log.Print(err) 17 } 18 } 19} 20 21func title(url string) error { 22 resp, err := http.Get(url) 23 if err != nil { 24 return err 25 } 26 27 defer resp.Body.Close() 28 29 ct := resp.Header.Get("Content-Type") 30 if ct != "text/html" && !strings.HasPrefix(ct, "text/html;") { 31 return fmt.Errorf("%s has type %s, not text/html", url, ct) 32 } 33 34 doc, err := html.Parse(resp.Body) 35 if err != nil { 36 return fmt.Errorf("parsing %s as HTML: %v", url, err) 37 } 38 visitNode := func(n *html.Node) { 39 if n.Type == html.ElementNode && n.Data == "title" && n.FirstChild != nil { 40 fmt.Println(n.FirstChild.Data) 41 } 42 } 43 forEachNode(doc, visitNode, nil) 44 return nil 45} 46 47func forEachNode(n *html.Node, pre, post func(n *html.Node)) { 48 if pre != nil { 49 pre(n) 50 } 51 for c := n.FirstChild; c != nil; c = c.NextSibling { 52 forEachNode(c, pre, post) 53 } 54 if post != nil { 55 post(n) 56 } 57}
接口
bytecounter.go
实现 Write 方法
// Bytecounter shows that any type with a suitable Write method can be
// the destination of fmt.Fprintf.
package main

import "fmt"

// ByteCounter discards everything written to it and only accumulates
// the number of bytes.
type ByteCounter int

// Write adds len(p) to the counter and always reports full success.
func (c *ByteCounter) Write(p []byte) (int, error) {
	n := len(p)
	*c += ByteCounter(n)
	return n, nil
}

func main() {
	var c ByteCounter
	c.Write([]byte("hello"))
	fmt.Println(c)

	// Reset, then count the bytes Fprintf produces for "hello, Dolly".
	c = 0
	name := "Dolly"
	fmt.Fprintf(&c, "hello, %s", name)
	fmt.Println(c)
}
flag 接口
sleep.go
// Sleep sleeps for the duration given by the -period flag (default 1s),
// demonstrating flag.Duration.
package main

import (
	"flag"
	"fmt"
	"time"
)

var period = flag.Duration("period", 1*time.Second, "sleep period")

func main() {
	flag.Parse()
	d := *period
	fmt.Printf("Sleeping for %v", d)
	time.Sleep(d)
	fmt.Println()
}
sort.Interface 接口
sorting.go
// Sorting demonstrates sort.Interface by sorting a playlist by artist,
// in reverse, by year, and with a custom multi-key comparison.
package main

import (
	"fmt"
	"os"
	"sort"
	"text/tabwriter"
	"time"
)

func main() {
	printTracks(tracks)

	sort.Sort(byArtist(tracks))
	printTracks(tracks)

	sort.Sort(sort.Reverse(byArtist(tracks)))
	printTracks(tracks)

	sort.Sort(byYear(tracks))
	printTracks(tracks)

	sort.Sort(customSort{tracks, byTitleYearLength})
	printTracks(tracks)

	fmt.Println(sort.IsSorted(byYear(tracks)))
}

// byTitleYearLength orders tracks by Title, breaking ties by Year and
// then by Length.
func byTitleYearLength(x, y *Track) bool {
	switch {
	case x.Title != y.Title:
		return x.Title < y.Title
	case x.Year != y.Year:
		return x.Year < y.Year
	default:
		return x.Length < y.Length
	}
}

// byArtist sorts tracks by Artist.
type byArtist []*Track

func (x byArtist) Len() int           { return len(x) }
func (x byArtist) Less(i, j int) bool { return x[i].Artist < x[j].Artist }
func (x byArtist) Swap(i, j int)      { x[i], x[j] = x[j], x[i] }

// byYear sorts tracks by Year.
type byYear []*Track

func (x byYear) Len() int           { return len(x) }
func (x byYear) Less(i, j int) bool { return x[i].Year < x[j].Year }
func (x byYear) Swap(i, j int)      { x[i], x[j] = x[j], x[i] }

// customSort sorts t using an arbitrary less function.
type customSort struct {
	t    []*Track
	less func(x, y *Track) bool
}

func (x customSort) Len() int           { return len(x.t) }
func (x customSort) Less(i, j int) bool { return x.less(x.t[i], x.t[j]) }
func (x customSort) Swap(i, j int)      { x.t[i], x.t[j] = x.t[j], x.t[i] }

// Track is one entry in a playlist.
type Track struct {
	Title  string
	Artist string
	Album  string
	Year   int
	Length time.Duration
}

var tracks = []*Track{
	{"Go", "Delilah", "From the Roots Up", 2012, length("3m38s")},
	{"Go", "Moby", "Moby", 1992, length("3m37s")},
	{"Go Ahead", "Alicia Keys", "As I Am", 2007, length("4m36s")},
	{"Ready 2 Go", "Martin Solveig", "Smash", 2011, length("4m24s")},
}

// length parses a duration literal such as "3m38s" and panics with the
// input string if it is malformed.
func length(s string) time.Duration {
	d, err := time.ParseDuration(s)
	if err != nil {
		panic(s)
	}
	return d
}

// printTracks writes list to stdout as an aligned table, followed by a
// blank line.
func printTracks(list []*Track) {
	const format = "%v\t%v\t%v\t%v\t%v\t\n"
	tw := tabwriter.NewWriter(os.Stdout, 0, 8, 2, ' ', 0)
	fmt.Fprintf(tw, format, "Title", "Artist", "Album", "Year", "Length")
	fmt.Fprintf(tw, format, "-----", "------", "-----", "----", "------")
	for _, tr := range list {
		fmt.Fprintf(tw, format, tr.Title, tr.Artist, tr.Album, tr.Year, tr.Length)
	}
	tw.Flush()
	fmt.Println()
}
Goroutines
spinner.go
// Spinner computes a slow Fibonacci number while a background goroutine
// animates a text spinner on the terminal.
package main

import (
	"fmt"
	"time"
)

// fib returns the x-th Fibonacci number using the naive (exponential)
// recursion, which makes it deliberately slow.
func fib(x int) int {
	if x >= 2 {
		return fib(x-1) + fib(x-2)
	}
	return x
}

// spinner cycles through the frames forever, redrawing in place.
func spinner(delay time.Duration) {
	const frames = `-\|/`
	for {
		for _, r := range frames {
			fmt.Printf("\r%c", r)
			time.Sleep(delay)
		}
	}
}

func main() {
	go spinner(100 * time.Millisecond)
	const n = 45
	fibN := fib(n)
	fmt.Printf("\rFibonacci(%d) = %d\n", n, fibN)
}
clock1.go
// Clock1 is a TCP server that writes the current time to each client
// once per second; clients are served concurrently.
package main

import (
	"io"
	"log"
	"net"
	"time"
)

func main() {
	listener, err := net.Listen("tcp", "localhost:8000")
	if err != nil {
		log.Fatal(err)
	}

	for {
		conn, err := listener.Accept()
		if err != nil {
			// Log and keep accepting other connections.
			log.Print(err)
			continue
		}
		go handleConn(conn)
	}
}

// handleConn writes an HH:MM:SS line to c every second until a write
// fails, then closes c.
func handleConn(c net.Conn) {
	defer c.Close()
	for {
		stamp := time.Now().Format("15:04:05\n")
		if _, err := io.WriteString(c, stamp); err != nil {
			return
		}
		time.Sleep(1 * time.Second)
	}
}
netcat1.go
// Netcat1 is a read/write TCP client for localhost:8000.
package main

import (
	"io"
	"log"
	"net"
	"os"
)

func main() {
	conn, err := net.Dial("tcp", "localhost:8000")
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	// Server output is copied to stdout in the background while stdin is
	// copied to the server in the foreground.
	go mustCopy(os.Stdout, conn)
	mustCopy(conn, os.Stdin)
}

// mustCopy copies src to dst until EOF, exiting via log.Fatal on error.
func mustCopy(dst io.Writer, src io.Reader) {
	_, err := io.Copy(dst, src)
	if err != nil {
		log.Fatal(err)
	}
}
Channels
netcat3.go
// Netcat3 is a TCP client for localhost:8000 that waits for its
// background reader to finish before exiting.
package main

import (
	"io"
	"log"
	"net"
	"os"
)

func main() {
	conn, err := net.Dial("tcp", "localhost:8000")
	if err != nil {
		log.Fatal(err)
	}
	finished := make(chan struct{})
	go drain(conn, finished)
	mustCopy(conn, os.Stdin)
	conn.Close()
	<-finished
}

// drain copies everything the server sends to stdout, logs "done" when
// the copy ends, then signals on finished. Copy errors are deliberately
// ignored: main's conn.Close is expected to end the copy.
func drain(conn net.Conn, finished chan<- struct{}) {
	io.Copy(os.Stdout, conn)
	log.Println("done")
	finished <- struct{}{}
}

// mustCopy copies src to dst until EOF, exiting via log.Fatal on error.
func mustCopy(dst io.Writer, src io.Reader) {
	_, err := io.Copy(dst, src)
	if err != nil {
		log.Fatal(err)
	}
}
pipeline1.go
// Pipeline1 connects a counter, a squarer and a printer with unbuffered
// channels. It runs forever.
package main

import (
	"fmt"
	"time"
)

func main() {
	naturals := make(chan int)
	squares := make(chan int)

	// Counter: sends 0, 1, 2, ... pausing half a second after each.
	go func() {
		x := 0
		for {
			naturals <- x
			x++
			time.Sleep(500 * time.Millisecond)
		}
	}()

	// Squarer
	go func() {
		for {
			n := <-naturals
			squares <- n * n
		}
	}()

	// Printer (in main goroutine)
	for {
		fmt.Println(<-squares)
	}
}
pipeline2.go
// Pipeline2 is a finite counter -> squarer -> printer pipeline in which
// each stage closes its output channel when done.
package main

import (
	"fmt"
	"time"
)

func main() {
	naturals := make(chan int)
	squares := make(chan int)

	// Counter: sends 0..99, pausing 100ms after each, then closes naturals.
	go func() {
		defer close(naturals)
		for x := 0; x < 100; x++ {
			naturals <- x
			time.Sleep(100 * time.Millisecond)
		}
	}()

	// Squarer: squares values until naturals is closed, then closes squares.
	go func() {
		defer close(squares)
		for x := range naturals {
			squares <- x * x
		}
	}()

	// Printer (in main goroutine): runs until squares is closed.
	for sq := range squares {
		fmt.Println(sq)
	}
}
pipeline3.go
使用单方向的 channel 类型
// Pipeline3 is the counter -> squarer -> printer pipeline written with
// unidirectional channel parameter types.
package main

import (
	"fmt"
)

// counter sends 0..99 on out, then closes it.
func counter(out chan<- int) {
	defer close(out)
	for x := 0; x < 100; x++ {
		out <- x
	}
}

// squarer sends the square of each value received from in, closing out
// once in is closed.
func squarer(out chan<- int, in <-chan int) {
	defer close(out)
	for v := range in {
		out <- v * v
	}
}

// printer prints every value received until in is closed.
func printer(in <-chan int) {
	for v := range in {
		fmt.Println(v)
	}
}

func main() {
	naturals := make(chan int)
	squares := make(chan int)

	go counter(naturals)
	go squarer(squares, naturals)
	printer(squares)
}
并发爬虫
crawl1.go
无限制并发
1package main 2 3import ( 4 "fmt" 5 "log" 6 "net/http" 7 "os" 8 9 "golang.org/x/net/html" 10) 11 12func main() { 13 worklist := make(chan []string) 14 go func() { worklist <- os.Args[1:] }() 15 seen := make(map[string]bool) 16 for list := range worklist { 17 for _, link := range list { 18 if !seen[link] { 19 seen[link] = true 20 go func(link string) { 21 worklist <- crawl(link) 22 }(link) 23 } 24 } 25 } 26} 27 28func crawl(url string) []string { 29 fmt.Println(url) 30 list, err := Extract(url) 31 if err != nil { 32 log.Print(err) 33 } 34 return list 35} 36 37func Extract(url string) ([]string, error) { 38 resp, err := http.Get(url) 39 if err != nil { 40 return nil, err 41 } 42 if resp.StatusCode != http.StatusOK { 43 resp.Body.Close() 44 return nil, fmt.Errorf("getting %s: %s", url, resp.Status) 45 } 46 doc, err := html.Parse(resp.Body) 47 resp.Body.Close() 48 if err != nil { 49 return nil, fmt.Errorf("parsing %s as HTML: %v", url, err) 50 } 51 var links []string 52 visitNode := func(n *html.Node) { 53 if n.Type == html.ElementNode && n.Data == "a" { 54 for _, a := range n.Attr { 55 if a.Key != "href" { 56 continue 57 } 58 link, err := resp.Request.URL.Parse(a.Val) 59 if err != nil { 60 continue 61 } 62 links = append(links, link.String()) 63 } 64 } 65 } 66 forEachNode(doc, visitNode, nil) 67 return links, nil 68} 69 70func forEachNode(n *html.Node, pre, post func(n *html.Node)) { 71 if pre != nil { 72 pre(n) 73 } 74 for c := n.FirstChild; c != nil; c = c.NextSibling { 75 forEachNode(c, pre, post) 76 } 77 if post != nil { 78 post(n) 79 } 80}
crawl2.go
将对 Extract 的调用用获取、释放 token 的操作包裹起来,以确保同一时间最多只有 20 个 Extract 调用在进行
1package main 2 3import ( 4 "fmt" 5 "log" 6 "net/http" 7 "os" 8 9 "golang.org/x/net/html" 10) 11 12func main() { 13 worklist := make(chan []string) 14 // number of pending sends to worklist 15 var n int 16 17 // start with the command-line arguments 18 n++ 19 go func() { worklist <- os.Args[1:] }() 20 21 // crawl the web concurrently 22 seen := make(map[string]bool) 23 24 for ; n > 0; n-- { 25 list := <-worklist 26 for _, link := range list { 27 if !seen[link] { 28 seen[link] = true 29 n++ 30 go func(link string) { 31 worklist <- crawl(link) 32 }(link) 33 } 34 } 35 } 36} 37 38var tokens = make(chan struct{}, 20) 39 40func crawl(url string) []string { 41 fmt.Println(url) 42 // acquire a token 43 tokens <- struct{}{} 44 list, err := Extract(url) 45 // release the token 46 <-tokens 47 if err != nil { 48 log.Print(err) 49 } 50 return list 51} 52 53func Extract(url string) ([]string, error) { 54 resp, err := http.Get(url) 55 if err != nil { 56 return nil, err 57 } 58 if resp.StatusCode != http.StatusOK { 59 resp.Body.Close() 60 return nil, fmt.Errorf("getting %s: %s", url, resp.Status) 61 } 62 doc, err := html.Parse(resp.Body) 63 resp.Body.Close() 64 if err != nil { 65 return nil, fmt.Errorf("parsing %s as HTML: %v", url, err) 66 } 67 var links []string 68 visitNode := func(n *html.Node) { 69 if n.Type == html.ElementNode && n.Data == "a" { 70 for _, a := range n.Attr { 71 if a.Key != "href" { 72 continue 73 } 74 link, err := resp.Request.URL.Parse(a.Val) 75 if err != nil { 76 continue 77 } 78 links = append(links, link.String()) 79 } 80 } 81 } 82 forEachNode(doc, visitNode, nil) 83 return links, nil 84} 85 86func forEachNode(n *html.Node, pre, post func(n *html.Node)) { 87 if pre != nil { 88 pre(n) 89 } 90 for c := n.FirstChild; c != nil; c = c.NextSibling { 91 forEachNode(c, pre, post) 92 } 93 if post != nil { 94 post(n) 95 } 96}
基于 select 的多路复用
countdown1.go
// Countdown1 counts down from 10 at one-second intervals, then launches.
package main

import (
	"fmt"
	"time"
)

func main() {
	fmt.Println("Commencing countdown.")
	ticks := time.Tick(1 * time.Second)
	countdown := 10
	for countdown > 0 {
		fmt.Println(countdown)
		<-ticks
		countdown--
	}
	launch()
}

func launch() {
	fmt.Println("launch")
}
countdown2.go
// Countdown2 waits ten seconds before launching, unless the user
// presses return first.
package main

import (
	"fmt"
	"os"
	"time"
)

func main() {
	abort := make(chan struct{})
	go waitForKey(abort)

	fmt.Println("Commencing countdown. Press return to abort.")
	select {
	case <-abort:
		fmt.Println("Launch aborted!")
		return
	case <-time.After(10 * time.Second):
		// The countdown elapsed without an abort.
	}
	launch()
}

// waitForKey blocks until a read of one byte from stdin returns (data,
// EOF or error alike), then sends on abort.
func waitForKey(abort chan<- struct{}) {
	os.Stdin.Read(make([]byte, 1))
	abort <- struct{}{}
}

func launch() {
	fmt.Println("launch")
}
countdown3.go
// Countdown3 prints a ten-second countdown and then launches, unless the
// user presses return to abort.
package main

import (
	"fmt"
	"os"
	"time"
)

func main() {
	abort := make(chan struct{})
	go func() {
		// Any read completion (a byte, EOF or error) triggers the abort.
		os.Stdin.Read(make([]byte, 1))
		abort <- struct{}{}
	}()

	fmt.Println("Commencing countdown. Press return to abort.")
	// Use a Ticker that can be stopped: time.Tick exposes only the
	// channel, so its ticker could never be stopped once main stops
	// listening (for example after an abort).
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()
	for countdown := 10; countdown > 0; countdown-- {
		fmt.Println(countdown)
		select {
		case <-ticker.C:
		case <-abort:
			fmt.Println("Launch aborted!")
			return
		}
	}
	launch()
}

func launch() {
	fmt.Println("launch")
}
示例: 并发的目录遍历
du1.go
// Du1 reports the number of files and total size of the named
// directories (default ".").
package main

import (
	"flag"
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
)

func main() {
	flag.Parse()
	roots := flag.Args()
	if len(roots) == 0 {
		roots = []string{"."}
	}

	// One goroutine walks every root in turn, then closes fileSizes so
	// that the range loop below terminates.
	fileSizes := make(chan int64)
	go func() {
		defer close(fileSizes)
		for _, root := range roots {
			walkDir(root, fileSizes)
		}
	}()

	var nfiles, nbytes int64
	for size := range fileSizes {
		nfiles, nbytes = nfiles+1, nbytes+size
	}
	printDiskUsage(nfiles, nbytes)
}

// printDiskUsage prints the file count and total size in GB (1e9 bytes).
func printDiskUsage(nfiles, nbytes int64) {
	fmt.Printf("%d files %.1f GB\n", nfiles, float64(nbytes)/1e9)
}

// walkDir recursively walks the tree rooted at dir and sends the size of
// each non-directory entry on fileSizes.
func walkDir(dir string, fileSizes chan<- int64) {
	for _, entry := range dirents(dir) {
		if !entry.IsDir() {
			fileSizes <- entry.Size()
			continue
		}
		walkDir(filepath.Join(dir, entry.Name()), fileSizes)
	}
}

// dirents returns the entries of directory dir. On error it reports the
// problem on stderr and returns nil.
func dirents(dir string) []os.FileInfo {
	entries, err := ioutil.ReadDir(dir)
	if err != nil {
		fmt.Fprintf(os.Stderr, "du1: %v\n", err)
		return nil
	}
	return entries
}
du2.go
// Du2 reports the disk usage of the named directories (default "."),
// printing running totals every 500ms when -v is given.
package main

import (
	"flag"
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
	"time"
)

var verbose = flag.Bool("v", false, "show verbose progress messages")

func main() {
	// Determine the initial directories
	flag.Parse()
	roots := flag.Args()
	if len(roots) == 0 {
		roots = []string{"."}
	}

	// Traverse the file tree
	fileSizes := make(chan int64)
	go func() {
		for _, root := range roots {
			walkDir(root, fileSizes)
		}
		close(fileSizes)
	}()

	// tick stays nil unless -v is set; a receive from a nil channel never
	// proceeds, so that select case is then never chosen.
	var tick <-chan time.Time
	if *verbose {
		tick = time.Tick(500 * time.Millisecond)
	}

	// Print the results
	var nfiles, nbytes int64

loop:
	for {
		select {
		case size, ok := <-fileSizes:
			if !ok {
				// fileSizes was closed: the walk is complete.
				break loop
			}
			nfiles++
			nbytes += size
		case <-tick:
			printDiskUsage(nfiles, nbytes)
		}
	}
	printDiskUsage(nfiles, nbytes)
}

// printDiskUsage prints the file count and total size in GB (1e9 bytes).
func printDiskUsage(nfiles, nbytes int64) {
	fmt.Printf("%d files %.1f GB\n", nfiles, float64(nbytes)/1e9)
}

// walkDir recursively walks the tree rooted at dir and sends the size of
// each non-directory entry on fileSizes.
func walkDir(dir string, fileSizes chan<- int64) {
	for _, entry := range dirents(dir) {
		if entry.IsDir() {
			subdir := filepath.Join(dir, entry.Name())
			walkDir(subdir, fileSizes)
		} else {
			fileSizes <- entry.Size()
		}
	}
}

// dirents returns the entries of directory dir. On error it reports the
// problem on stderr, prefixed with this program's name, and returns nil.
func dirents(dir string) []os.FileInfo {
	entries, err := ioutil.ReadDir(dir)
	if err != nil {
		fmt.Fprintf(os.Stderr, "du2: %v\n", err)
		return nil
	}
	return entries
}
du3.go
// Du3 reports the disk usage of the named directories (default "."),
// walking subdirectories in parallel. With -v it prints running totals
// every 500ms.
package main

import (
	"flag"
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
	"sync"
	"time"
)

var verbose = flag.Bool("v", false, "show verbose progress messages")

func main() {
	// Determine the initial directories
	flag.Parse()
	roots := flag.Args()
	if len(roots) == 0 {
		roots = []string{"."}
	}

	// Traverse each root of the file tree in parallel
	fileSizes := make(chan int64)
	var n sync.WaitGroup
	for _, root := range roots {
		n.Add(1)
		go walkDir(root, &n, fileSizes)
	}
	// Close fileSizes once every walkDir goroutine (including the ones
	// they spawn) has called Done.
	go func() {
		n.Wait()
		close(fileSizes)
	}()

	// tick stays nil unless -v is set, disabling that select case.
	var tick <-chan time.Time
	if *verbose {
		tick = time.Tick(500 * time.Millisecond)
	}

	// Print the results
	var nfiles, nbytes int64

loop:
	for {
		select {
		case size, ok := <-fileSizes:
			if !ok {
				break loop
			}
			nfiles++
			nbytes += size
		case <-tick:
			printDiskUsage(nfiles, nbytes)
		}
	}
	printDiskUsage(nfiles, nbytes)
}

// printDiskUsage prints the file count and total size in GB (1e9 bytes).
func printDiskUsage(nfiles, nbytes int64) {
	fmt.Printf("%d files %.1f GB\n", nfiles, float64(nbytes)/1e9)
}

// walkDir walks the tree rooted at dir, sending the size of each
// non-directory entry on fileSizes and starting a new goroutine for each
// subdirectory. Each goroutine is registered with n before it starts.
func walkDir(dir string, n *sync.WaitGroup, fileSizes chan<- int64) {
	defer n.Done()
	for _, entry := range dirents(dir) {
		if entry.IsDir() {
			n.Add(1)
			subdir := filepath.Join(dir, entry.Name())
			go walkDir(subdir, n, fileSizes)
		} else {
			fileSizes <- entry.Size()
		}
	}
}

// sema is a counting semaphore for limiting concurrency in dirents
var sema = make(chan struct{}, 20)

// dirents returns the entries of directory dir, holding a sema token
// during the read. On error it reports the problem on stderr, prefixed
// with this program's name, and returns nil.
func dirents(dir string) []os.FileInfo {
	sema <- struct{}{}        // acquire token
	defer func() { <-sema }() // release token
	entries, err := ioutil.ReadDir(dir)
	if err != nil {
		fmt.Fprintf(os.Stderr, "du3: %v\n", err)
		return nil
	}
	return entries
}
并发的退出
du4.go
1package main 2 3import ( 4 "flag" 5 "fmt" 6 "io/ioutil" 7 "os" 8 "path/filepath" 9 "sync" 10 "time" 11) 12 13var verbose = flag.Bool("v", false, "show verbose progress messages") 14 15var done = make(chan struct{}) 16 17func cancelled() bool { 18 select { 19 case <-done: 20 return true 21 default: 22 return false 23 } 24} 25 26func main() { 27 // Determine the initial directories 28 flag.Parse() 29 roots := flag.Args() 30 if len(roots) == 0 { 31 roots = []string{"."} 32 } 33 34 // Cancel traversal when input is detected 35 go func() { 36 os.Stdin.Read(make([]byte, 1)) 37 close(done) 38 }() 39 40 // Traverse each root of the file tree in parallel 41 fileSizes := make(chan int64) 42 var n sync.WaitGroup 43 for _, root := range roots { 44 n.Add(1) 45 go walkDir(root, &n, fileSizes) 46 } 47 go func() { 48 n.Wait() 49 close(fileSizes) 50 }() 51 52 var tick <-chan time.Time 53 if *verbose { 54 tick = time.Tick(500 * time.Millisecond) 55 } 56 57 // Print the results 58 var nfiles, nbytes int64 59 60loop: 61 for { 62 select { 63 case <-done: 64 // Drain fileSizes to allow existing goroutines to finish 65 for range fileSizes { 66 // Do nothing 67 } 68 return 69 case size, ok := <-fileSizes: 70 if !ok { 71 break loop 72 } 73 nfiles++ 74 nbytes += size 75 case <-tick: 76 printDiskUsage(nfiles, nbytes) 77 } 78 } 79 printDiskUsage(nfiles, nbytes) 80} 81 82func printDiskUsage(nfiles, nbytes int64) { 83 fmt.Printf("%d files %.1f GB\n", nfiles, float64(nbytes)/1e9) 84} 85 86func walkDir(dir string, n *sync.WaitGroup, fileSizes chan<- int64) { 87 defer n.Done() 88 if cancelled() { 89 return 90 } 91 for _, entry := range dirents(dir) { 92 if entry.IsDir() { 93 n.Add(1) 94 subdir := filepath.Join(dir, entry.Name()) 95 go walkDir(subdir, n, fileSizes) 96 } else { 97 fileSizes <- entry.Size() 98 } 99 } 100} 101 102// sema is a counting semaphore for limiting concurrency in dirents 103var sema = make(chan struct{}, 20) 104 105func dirents(dir string) []os.FileInfo { 106 select { 107 
case sema <- struct{}{}: // acquire token 108 case <-done: 109 return nil // cancelled 110 } 111 defer func() { <-sema }() // release token 112 entries, err := ioutil.ReadDir(dir) 113 if err != nil { 114 fmt.Fprintf(os.Stderr, "du1: %v\n", err) 115 return nil 116 } 117 return entries 118}
示例:聊天服务
chat.go
1package main 2 3import ( 4 "bufio" 5 "fmt" 6 "log" 7 "net" 8) 9 10func main() { 11 listener, err := net.Listen("tcp", "localhost:8000") 12 if err != nil { 13 log.Fatal(err) 14 } 15 16 go broadcaster() 17 18 for { 19 conn, err := listener.Accept() 20 if err != nil { 21 log.Print(err) 22 continue 23 } 24 go handleConn(conn) 25 } 26} 27 28type client chan<- string // an outgoing message channel 29 30var ( 31 entering = make(chan client) 32 leaving = make(chan client) 33 messages = make(chan string) // all incoming client messages 34) 35 36func broadcaster() { 37 clients := make(map[client]bool) // all connected clients 38 for { 39 select { 40 case msg := <-messages: 41 // Broadcast incoming message to all 42 // clients' outgoing message channels 43 for cli := range clients { 44 cli <- msg 45 } 46 47 case cli := <-entering: 48 clients[cli] = true 49 50 case cli := <-leaving: 51 delete(clients, cli) 52 close(cli) 53 } 54 } 55} 56 57func handleConn(conn net.Conn) { 58 ch := make(chan string) //outgoing client messages 59 go clientWriter(conn, ch) 60 61 who := conn.RemoteAddr().String() 62 ch <- "You are" + who 63 messages <- who + " has arrived" 64 entering <- ch 65 66 input := bufio.NewScanner(conn) 67 for input.Scan() { 68 messages <- who + ": " + input.Text() 69 } 70 // NOTE: ignoring potential errors from input.Err() 71 72 leaving <- ch 73 messages <- who + " has left" 74 conn.Close() 75} 76 77func clientWriter(conn net.Conn, ch <-chan string) { 78 for msg := range ch { 79 fmt.Fprintln(conn, msg) // NOTE: ignoring network errors 80 } 81}
竞争条件
bank1.go
balance 变量被限制在了 monitor goroutine 中
1package bank 2 3var deposits = make(chan int) // send amout to deposit 4var balances = make(chan int) // receive balance 5 6func Deposit(amount int) { deposits <- amount } 7func Balance() int { return <-balances } 8 9func teller() { 10 var balance int // balance is confined to teller goroutine 11 for { 12 select { 13 case amount := <-deposits: 14 balance += amount 15 case balances <- balance: 16 } 17 } 18} 19 20func init() { 21 go teller() // start the monitor goroutine 22}
bank2.go
我们可以用一个容量只有 1 的 channel 来保证同一时刻最多只有一个 goroutine 访问共享变量:向 channel 发送一个值表示获取令牌,从 channel 接收一个值表示释放令牌。这种计数只能为 0 或 1 的信号量叫做二元信号量(binary semaphore)。
1package bank 2 3var ( 4 sema = make(chan struct{}, 1) // a binary semaphore guarding balance 5 balance int 6) 7 8func Deposit(amount int) { 9 sema <- struct{}{} // acquire token 10 balance = balance + amount 11 <-sema // release token 12} 13 14func Balance() int { 15 sema <- struct{}{} // acquire token 16 b := balance 17 <-sema // release token 18 return b 19}
bank3.go
sync.Mutex 互斥锁(下面的示例使用它的读写锁版本 sync.RWMutex:Deposit 获取写锁,Balance 只获取读锁)
1package bank 2 3import "sync" 4 5var ( 6 mu sync.RWMutex // guards balance 7 balance int 8) 9 10func Deposit(amount int) { 11 mu.Lock() 12 defer mu.Unlock() 13 balance = balance + amount 14} 15 16func Balance() int { 17 mu.RLock() // readers lock 18 defer mu.RUnlock() 19 return balance 20}
示例: 并发的非阻塞缓存
memo1.go
1package memo 2 3type Memo struct { 4 f Func 5 cache map[string]result 6} 7 8type Func func(key string) (interface{}, error) 9 10type result struct { 11 value interface{} 12 err error 13} 14 15func New(f Func) *Memo { 16 return &Memo{f: f, cache: make(map[string]result)} 17} 18 19// not concurrency-safe 20func (memo *Memo) Get(key string) (interface{}, error) { 21 res, ok := memo.cache[key] 22 if !ok { 23 res.value, res.err = memo.f(key) 24 mome.cache[key] = res 25 } 26 return res.value, res.err 27}
memo2.go
1package memo 2 3import "sync" 4 5type Memo struct { 6 f Func 7 mu sync.Mutex // guards cache 8 cache map[string]result 9} 10 11type Func func(key string) (interface{}, error) 12 13type result struct { 14 value interface{} 15 err error 16} 17 18func New(f Func) *Memo { 19 return &Memo{f: f, cache: make(map[string]result)} 20} 21 22// Get is concurrency-safe 23// 完全丧失了并发的性能优点 24// 每次对 f 的调用期间都会持有锁 25// Get 将本来可以并行运行的 I/O 操作串行化了 26func (memo *Memo) Get(key string) (interface{}, error) { 27 memo.mu.Lock() 28 res, ok := memo.cache[key] 29 if !ok { 30 res.value, res.err = memo.f(key) 31 mome.cache[key] = res 32 } 33 memo.mu.Unlock() 34 return res.value, res.err 35}
memo3.go
1package memo 2 3import "sync" 4 5type Memo struct { 6 f Func 7 mu sync.Mutex // guards cache 8 cache map[string]result 9} 10 11type Func func(key string) (interface{}, error) 12 13type result struct { 14 value interface{} 15 err error 16} 17 18func New(f Func) *Memo { 19 return &Memo{f: f, cache: make(map[string]result)} 20} 21 22// Get is concurrency-safe 23// 两次获取锁的中间阶段,其他 goroutine 可以随意使用 cache 24// 使性能得到提升 25// 问题:同一时刻请求相同 url 时会重复请求,存在多余的工作 26func (memo *Memo) Get(key string) (interface{}, error) { 27 memo.mu.Lock() 28 res, ok := memo.cache[key] 29 memo.mu.Unlock() 30 if !ok { 31 res.value, res.err = memo.f(key) 32 33 // Between the two critical sections, several goroutines 34 // may race to compute f(key) and update the map 35 memo.mu.Lock() 36 mome.cache[key] = res 37 memo.mu.Unlock() 38 } 39 return res.value, res.err 40}
memo4.go
实现使用了一个互斥量来保护多个 goroutine 调用 Get 时的共享 map 变量;每个 entry 的 ready channel 会在结果计算完成后被关闭,用来通知等待同一个 key 的其他 goroutine,从而保证每个 key 只调用一次 f。
1package memo 2 3import "sync" 4 5type entry struct { 6 res result 7 ready chan struct{} // closed when res is ready 8} 9 10type Memo struct { 11 f Func 12 mu sync.Mutex // guards cache 13 cache map[string]*entry 14} 15 16type Func func(key string) (interface{}, error) 17 18type result struct { 19 value interface{} 20 err error 21} 22 23func New(f Func) *Memo { 24 return &Memo{f: f, cache: make(map[string]*entry)} 25} 26 27// Get is concurrency-safe 28func (memo *Memo) Get(key string) (interface{}, error) { 29 memo.mu.Lock() 30 e := memo.cache[key] 31 if e == nil { 32 // This is the first request for this key 33 // This goroutine becomes responsible for computing 34 // the value and broadcasting the ready condition 35 e = &entry{ready: make(chan struct{})} 36 memo.cache[key] = e 37 memo.mu.Unlock() 38 39 e.res.value, e.res.err = memo.f(key) 40 41 close(e.ready) // broadcast ready condition 42 } else { 43 // This is a repeat request for this key 44 memo.mu.Unlock() 45 46 <-e.ready // wait for ready condition 47 } 48 return e.res.value, e.res.err 49}
memo5.go
channel 通信的方式
1package memo 2 3// A request is a message requesting that the Func be applied to key 4type request struct { 5 key string 6 response chan<- result // the client wants a single result 7} 8 9type entry struct { 10 res result 11 ready chan struct{} // closed when res is ready 12} 13 14type Memo struct { 15 requests chan request 16} 17 18type Func func(key string) (interface{}, error) 19 20type result struct { 21 value interface{} 22 err error 23} 24 25// New returns a memoization of f 26// Clients must subsequently call Close 27func New(f Func) *Memo { 28 memo := &Memo{requests: make(chan request)} 29 go memo.server(f) 30 return memo 31} 32 33// Get is concurrency-safe 34func (memo *Memo) Get(key string) (interface{}, error) { 35 response := make(chan result) 36 memo.requests <- request{key, response} 37 res := <-response 38 return res.value, res.err 39} 40 41func (memo *Memo) Close() { close(memo.requests) } 42 43func (memo *Memo) server(f Func) { 44 cache := make(map[string]*entry) 45 for req := range mome.requests { 46 e := cache[req.key] 47 if e == nil { 48 // This is the first request for this key 49 e = &entry{ready: make(chan struct{})} 50 cache[req.key] = e 51 go e.call(f, req.key) 52 } 53 go e.deliver(req.response) 54 } 55} 56 57func (e *entry) call(f Func, key string) { 58 // Evaluate the function 59 e.res.value, e.res.err = f(key) 60 // Broadcast the ready condition 61 close(e.ready) 62} 63 64func (e *entry) deliver(response chan<- result) { 65 // Wait for the ready condition 66 <-e.ready 67 // Send the result to the client 68 response <- e.res 69}
测试
word.go
1// Package word provides utilities for word games. 2package word 3 4import "unicode" 5 6// IsPalindrome reports whether s reads the same forward and backward. 7// Letter case is ignored, as are non-letters. 8func IsPalindrome(s string) bool { 9 // letters := make([]rune, 0) 10 letters := make([]rune, 0, len(s)) 11 for _, r := range s { 12 if unicode.IsLetter(r) { 13 letters = append(letters, unicode.ToLower(r)) 14 } 15 } 16 n := len(letters) / 2 17 for i := 0; i < n; i++ { 18 if letters[i] != letters[len(letters)-1-i] { 19 return false 20 } 21 } 22 return true 23}
word_test.go
- 测试:go test
- 测试覆盖率:go test -coverprofile=c.out(可输入 go tool cover 查看帮助信息)
- 基准测试:go test -bench=.、go test -bench=. -benchmem
1package word 2 3import "testing" 4 5func TestIsPalindrome(t *testing.T) { 6 var tests = []struct { 7 input string 8 want bool 9 }{ 10 {"", true}, 11 {"a", true}, 12 {"aa", true}, 13 {"ab", false}, 14 {"kayak", true}, 15 {"detartrated", true}, 16 {"A man, a plan, a canal: Panama", true}, 17 {"Evil I did dwell; lewd did I live.", true}, 18 {"Able was I ere I saw Elba", true}, 19 {"été", true}, 20 {"Et se resservir, ivresse reste.", true}, 21 {"palindrome", false}, // non-palindrome 22 {"desserts", false}, // semi-palindrome 23 } 24 for _, test := range tests { 25 if got := IsPalindrome(test.input); got != test.want { 26 t.Errorf("IsPalindrome(%q) = %v", test.input, got) 27 } 28 } 29} 30 31func BenchmarkIsPalindrome(b *testing.B) { 32 for i := 0; i < b.N; i++ { 33 IsPalindrome("A man, a plan, a canal: Panama") 34 } 35}