Golang Code Example

2018年11月12日

echo

Unix 里 echo 命令的一份实现,echo 把它的命令行参数打印成一行

echo1.go

 1package main
 2
 3import (
 4	"fmt"
 5	"os"
 6)
 7
 8func main() {
 9	var s, sep string
10	for i := 1; i < len(os.Args); i++ {
11		s += sep + os.Args[i]
12		sep = " "
13	}
14	fmt.Println(s)
15}

echo2.go

 1package main
 2
 3import (
 4	"fmt"
 5	"os"
 6)
 7
 8func main() {
 9	s, sep := "", ""
10	for _, arg := range os.Args[1:] {
11		s += sep + arg
12		sep = " "
13	}
14	fmt.Println(s)
15}

echo3.go

 1package main
 2
 3import (
 4	"fmt"
 5	"strings"
 6	"os"
 7)
 8
 9func main() {
10	fmt.Println(strings.Join(os.Args[1:], " "))
11}

dup

dup 的第一个版本打印标准输入中多次出现的行,以重复次数开头

dup1.go

 1package main
 2
 3import (
 4	"bufio"
 5	"fmt"
 6	"os"
 7)
 8
 9func main() {
10	counts := make(map[string]int)
11	input := bufio.NewScanner(os.Stdin)
12	for input.Scan() {
13		counts[input.Text()]++
14	}
15	for line, n := range counts {
16		if n > 1 {
17			fmt.Printf("%d\t%s\n", n, line)
18		}
19	}
20}

dup 的第二个版本读取标准输入,或使用 os.Open 打开命令行中指定的各个文件,并统计其中重复出现的行

dup2.go

 1package main
 2
 3import (
 4	"bufio"
 5	"fmt"
 6	"os"
 7)
 8
 9func main() {
10	counts := make(map[string]int)
11	files := os.Args[1:]
12	if len(files) == 0 {
13		countLines(os.Stdin, counts)
14	} else {
15		for _, arg := range files {
16			f, err := os.Open(arg)
17			if err != nil {
18				fmt.Fprintf(os.Stderr, "dup2: %v\n", err)
19				continue
20			}
21			countLines(f, counts)
22			f.Close()
23		}
24	}
25
26	for line, n := range counts {
27		if n > 1 {
28			fmt.Printf("%d\t%s\n", n, line)
29		}
30	}
31}
32
// countLines reads f line by line with a bufio.Scanner and adds one to
// counts[line] for each line read. If reading stops on an error rather
// than EOF (for example a line longer than the scanner's buffer), the error
// is reported on stderr together with the file name. Lines counted before
// the error are kept.
func countLines(f *os.File, counts map[string]int) {
	input := bufio.NewScanner(f)
	for input.Scan() {
		counts[input.Text()]++
	}
	if err := input.Err(); err != nil {
		fmt.Fprintf(os.Stderr, "dup2: %s: %v\n", f.Name(), err)
	}
}

dup3.go

 1package main
 2
 3import (
 4	"fmt"
 5	"io/ioutil"
 6	"os"
 7	"strings"
 8)
 9
// main counts lines across the files named on the command line (standard
// input is not read) and prints every line that occurs more than once,
// preceded by its count and a tab. Each file is read into memory whole.
// Unreadable files are reported on stderr and skipped.
func main() {
	counts := make(map[string]int)
	for _, filename := range os.Args[1:] {
		data, err := ioutil.ReadFile(filename)
		if err != nil {
			fmt.Fprintf(os.Stderr, "dup3: %v\n", err)
			continue
		}
		lines := strings.Split(string(data), "\n")
		// A file ending in "\n" (or an empty file) yields a final "" element
		// that is not a real line. Counting it would report a phantom empty
		// "duplicate" whenever two or more such files are given.
		if last := len(lines) - 1; lines[last] == "" {
			lines = lines[:last]
		}
		for _, line := range lines {
			counts[line]++
		}
	}
	for line, n := range counts {
		if n > 1 {
			fmt.Printf("%d\t%s\n", n, line)
		}
	}
}

fetch

fetch.go

 1package main
 2
 3import (
 4	"fmt"
 5	"io/ioutil"
 6	"net/http"
 7	"os"
 8)
 9
// main fetches each URL named on the command line and copies its response
// body to standard output. The first failure is reported on stderr and
// ends the program with status 1.
func main() {
	for _, url := range os.Args[1:] {
		if err := fetchAndPrint(url); err != nil {
			fmt.Fprintf(os.Stderr, "fetch: %v\n", err)
			os.Exit(1)
		}
	}
}

// fetchAndPrint GETs url and writes the entire response body to stdout.
func fetchAndPrint(url string) error {
	resp, err := http.Get(url)
	if err != nil {
		return err
	}
	body, err := ioutil.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		return fmt.Errorf("reading %s: %v", url, err)
	}
	fmt.Printf("%s", body)
	return nil
}

fetchall.go

并发获取多个 URL

 1package main
 2
 3import (
 4	"fmt"
 5	"io"
 6	"io/ioutil"
 7	"net/http"
 8	"os"
 9	"time"
10)
11
12func main() {
13	start := time.Now()
14	ch := make(chan string)
15	for _, url := range os.Args[1:] {
16		go fetch(url, ch)
17	}
18	for range os.Args[1:] {
19		fmt.Println(<-ch)
20	}
21	fmt.Printf("%.2fs elapsed\n", time.Since(start).Seconds())
22}
23
// fetch GETs url, discards the body, and sends exactly one summary line on
// ch. On success the line holds the elapsed seconds, byte count and URL;
// otherwise it holds an error message.
func fetch(url string, ch chan<- string) {
	began := time.Now()
	resp, err := http.Get(url)
	if err != nil {
		ch <- fmt.Sprint(err)
		return
	}
	nbytes, copyErr := io.Copy(ioutil.Discard, resp.Body)
	resp.Body.Close()
	if copyErr != nil {
		ch <- fmt.Sprintf("while reading %s: %v", url, copyErr)
		return
	}
	ch <- fmt.Sprintf("%.2fs  %7d  %s", time.Since(began).Seconds(), nbytes, url)
}

server

server1.go

 1package main
 2
 3import (
 4	"fmt"
 5	"log"
 6	"net/http"
 7)
 8
 9func main() {
10	http.HandleFunc("/", handler)
11	log.Fatal(http.ListenAndServe("localhost:8000", nil))
12}
13
// handler echoes the Path component of the request URL, quoted.
func handler(w http.ResponseWriter, r *http.Request) {
	path := r.URL.Path
	fmt.Fprintf(w, "URL.Path = %q\n", path)
}

server2.go

 1package main
 2
 3import (
 4	"fmt"
 5	"log"
 6	"net/http"
 7	"sync"
 8)
 9
// count is the number of requests served by handler. In this program every
// read and write of count happens while holding mu.
var (
	mu    sync.Mutex
	count int
)
12
13func main() {
14	http.HandleFunc("/", handler)
15	http.HandleFunc("/count", counter)
16	log.Fatal(http.ListenAndServe("localhost:8000", nil))
17}
18
19func handler(w http.ResponseWriter, r *http.Request) {
20	mu.Lock()
21	count++
22	mu.Unlock()
23	fmt.Fprintf(w, "URL.Path = %q\n", r.URL.Path)
24}
25
26func counter(w http.ResponseWriter, r *http.Request) {
27	mu.Lock()
28	fmt.Fprintf(w, "Count %d\n", count)
29	mu.Unlock()
30}

server3.go

handler 函数会把请求的 http 头和请求的 form 数据都打印出来,这样可以使检查和调试这个服务更为方便

 1package main
 2
 3import (
 4	"fmt"
 5	"log"
 6	"net/http"
 7)
 8
 9func main() {
10	http.HandleFunc("/", handler)
11	log.Fatal(http.ListenAndServe("localhost:8000", nil))
12}
13
// handler writes back a plain-text dump of the request, in this order:
// the request line, every header, Host, RemoteAddr, and the values in
// r.Form after ParseForm. A ParseForm failure is logged, and whatever Form
// holds is still dumped. Header and form order follow map iteration and
// are unspecified.
func handler(w http.ResponseWriter, r *http.Request) {
	fmt.Fprintf(w, "%s %s %s\n", r.Method, r.URL, r.Proto)
	for name, values := range r.Header {
		fmt.Fprintf(w, "Header[%q] = %q\n", name, values)
	}
	fmt.Fprintf(w, "Host = %q\n", r.Host)
	fmt.Fprintf(w, "RemoteAddr = %q\n", r.RemoteAddr)

	err := r.ParseForm()
	if err != nil {
		log.Print(err)
	}
	for name, values := range r.Form {
		fmt.Fprintf(w, "Form[%q] = %q\n", name, values)
	}
}

basename

basename1.go

 1package main
 2
 3import "fmt"
 4
 5func basename(s string) string {
 6	for i := len(s) - 1; i >= 0; i-- {
 7		if s[i] == '/' {
 8			s = s[i+1:]
 9			break
10		}
11	}
12
13	for i := len(s) - 1; i >= 0; i-- {
14		if s[i] == '.' {
15			s = s[:i]
16			break
17		}
18	}
19	return s
20}
21
22func main() {
23	fmt.Println(basename("a/b/c.go"))
24	fmt.Println(basename("c.d.go"))
25	fmt.Println(basename("abc"))
26}

basename2.go

 1package main
 2
 3import "fmt"
 4import "strings"
 5
 6func basename(s string) string {
 7	slash := strings.LastIndex(s, "/")
 8	s = s[slash+1:]
 9	if dot := strings.LastIndex(s, "."); dot >= 0 {
10		s = s[:dot]
11	}
12	return s
13}
14
15func main() {
16	fmt.Println(basename("a/b/c.go"))
17	fmt.Println(basename("c.d.go"))
18	fmt.Println(basename("abc"))
19}

comma

comma.go

 1package main
 2
 3import "fmt"
 4
 5func comma(s string) string {
 6	n := len(s)
 7	if n <= 3 {
 8		return s
 9	}
10	return comma(s[:n-3]) + "," + s[n-3:]
11}
12
13func main() {
14	fmt.Println(comma("12345"))
15	fmt.Println(comma("12345678"))
16}

reverse

rev.go

 1package main
 2
 3import "fmt"
 4
 5func reverse(s []int) {
 6	for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
 7		s[i], s[j] = s[j], s[i]
 8	}
 9}
10
11func main() {
12	a := [...]int{0, 1, 2, 3, 4, 5}
13	reverse(a[:])
14	fmt.Println(a)
15}

dedup

dedup.go

程序读取多行输入,但是只打印第一次出现的行

 1package main
 2
 3import (
 4	"bufio"
 5	"fmt"
 6	"os"
 7)
 8
 9func main() {
10	// a set of strings
11	seen := make(map[string]bool)
12	input := bufio.NewScanner(os.Stdin)
13	for input.Scan() {
14		line := input.Text()
15		if !seen[line] {
16			seen[line] = true
17			fmt.Println(line)
18		}
19	}
20
21	if err := input.Err(); err != nil {
22		fmt.Fprintf(os.Stderr, "dedup: %v\n", err)
23		os.Exit(1)
24	}
25}

treesort

treesort.go

使用二叉树实现插入排序

 1package main
 2
 3import "fmt"
 4
 5func main() {
 6	values := []int{5, 4, 3, 9, 21, 4}
 7	Sort(values)
 8	fmt.Println(values)
 9}
10
// tree is a node of an unbalanced binary search tree of ints. add places
// smaller values in the left subtree and all others, including equal
// values, in the right subtree.
type tree struct {
	value int
	left  *tree
	right *tree
}
15
16func Sort(values []int) {
17	var root *tree
18	for _, v := range values {
19		root = add(root, v)
20	}
21	appendValues(values[:0], root)
22}
23
24func appendValues(values []int, t *tree) []int {
25	if t != nil {
26		values = appendValues(values, t.left)
27		values = append(values, t.value)
28		values = appendValues(values, t.right)
29	}
30	return values
31}
32
33func add(t *tree, value int) *tree {
34	if t == nil {
35		return &tree{value: value}
36	}
37	if value < t.value {
38		t.left = add(t.left, value)
39	} else {
40		t.right = add(t.right, value)
41	}
42	return t
43}

JSON

movie.go

 1package main
 2
 3import (
 4	"encoding/json"
 5	"fmt"
 6	"log"
 7)
 8
 9func main() {
10	data, err := json.Marshal(movies)
11	if err != nil {
12		log.Fatalf("JSON marshaling failed: %s", err)
13	}
14	fmt.Printf("%s\n", data)
15
16	data, err = json.MarshalIndent(movies, "", "  ")
17	if err != nil {
18		log.Fatalf("JSON marshaling failed: %s", err)
19	}
20	fmt.Printf("%s\n", data)
21
22	var titles []struct{ Title string }
23	if err := json.Unmarshal(data, &titles); err != nil {
24		log.Fatalf("JSON unmarshaling failed: %s", err)
25	}
26	fmt.Println(titles)
27
28	var items []Movie
29	if err := json.Unmarshal(data, &items); err != nil {
30		log.Fatalf("JSON unmarshaling failed: %s", err)
31	}
32	fmt.Println(items)
33}
34
// Movie is the record type that main encodes to and decodes from JSON.
// The struct tags rename Title to "title" and Year to "released" and mark
// Color "omitempty". Actors has no tag, so its Go field name is used.
type Movie struct {
	Title  string `json:"title"`
	Year   int    `json:"released"`
	Color  bool   `json:"color,omitempty"`
	Actors []string
}
41
42var movies = []Movie{
43	{
44		Title:  "Casablanca",
45		Year:   1942,
46		Color:  false,
47		Actors: []string{"Humphrey Bogart", "Ingrid Bergman"},
48	},
49	{
50		Title:  "Cool Hand Luke",
51		Year:   1967,
52		Color:  true,
53		Actors: []string{"Paul Newman"},
54	},
55	{
56		Title:  "Casablanca",
57		Year:   1942,
58		Color:  true,
59		Actors: []string{"Steve McQueen", "Jacqueline Bisset"},
60	},
61}

模板

github.go

 1package main
 2
 3import (
 4	"encoding/json"
 5	"fmt"
 6	"log"
 7	"net/http"
 8	"net/url"
 9	"os"
10	"strings"
11	"text/template"
12	"time"
13)
14
15func main() {
16	result, err := SearchIssues(os.Args[1:])
17	if err != nil {
18		log.Fatal(err)
19	}
20	fmt.Printf("%d issues:\n", result.TotalCount)
21	for _, item := range result.Items {
22		fmt.Printf("#%-5d %9.9s %.55s\n", item.Number, item.User.Login, item.Title)
23	}
24
25	fmt.Println()
26	if err := report.Execute(os.Stdout, result); err != nil {
27		log.Fatal(err)
28	}
29}
30
// templ is the text/template source parsed into report. It prints a header
// with the total count, then one block per issue showing its number,
// author login, title (formatted with "%.64s", i.e. at most 64 characters)
// and its age in days via the registered daysAgo function.
const templ = `{{.TotalCount}} issues:
{{range .Items}}----------------------------------------
Number: {{.Number}}
User: {{.User.Login}}
Title: {{.Title | printf "%.64s"}}
Age: {{.CreatedAt | daysAgo}} days
{{end}}`
38
39var report = template.Must(template.New("issuelist").
40	Funcs(template.FuncMap{"daysAgo": daysAgo}).
41	Parse(templ))
42
// daysAgo returns how many whole 24-hour periods have elapsed since t,
// truncated toward zero. The report template calls it to show issue ages.
func daysAgo(t time.Time) int {
	const hoursPerDay = 24
	elapsed := time.Since(t)
	return int(elapsed.Hours() / hoursPerDay)
}
46
// IssuesURL is the GitHub issue-search endpoint; SearchIssues appends
// "?q=" and the escaped query to it.
const IssuesURL = "https://api.github.com/search/issues"
48
// IssuesSearchResult is the decoded JSON body of an issue search performed
// by SearchIssues: the total match count and the returned issues.
type IssuesSearchResult struct {
	TotalCount int `json:"total_count"`
	Items      []*Issue
}
53
// Issue is one element of IssuesSearchResult.Items. Only the fields this
// program uses are decoded; JSON keys with underscores are mapped through
// struct tags.
type Issue struct {
	Number    int
	HTMLURL   string `json:"html_url"`
	Title     string
	State     string
	User      *User
	CreatedAt time.Time `json:"created_at"`
	Body      string
}
63
// User is the author of an Issue (see Issue.User): login name and profile
// URL.
type User struct {
	Login   string
	HTMLURL string `json:"html_url"`
}
68
69func SearchIssues(terms []string) (*IssuesSearchResult, error) {
70	q := url.QueryEscape(strings.Join(terms, " "))
71	resp, err := http.Get(IssuesURL + "?q=" + q)
72	if err != nil {
73		return nil, err
74	}
75
76	if resp.StatusCode != http.StatusOK {
77		resp.Body.Close()
78		return nil, fmt.Errorf("search query failed: %s", resp.Status)
79	}
80
81	var result IssuesSearchResult
82	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
83		resp.Body.Close()
84		return nil, err
85	}
86	resp.Body.Close()
87	return &result, nil
88}

递归

findlinks1.go

 1package main
 2
 3import (
 4	"fmt"
 5	"os"
 6
 7	"golang.org/x/net/html"
 8)
 9
10func main() {
11	doc, err := html.Parse(os.Stdin)
12	if err != nil {
13		fmt.Fprintf(os.Stderr, "findlinks1: %v\n", err)
14		os.Exit(1)
15	}
16	for _, link := range visit(nil, doc) {
17		fmt.Println(link)
18	}
19
20	fmt.Println()
21	outline(nil, doc)
22}
23
24func visit(links []string, n *html.Node) []string {
25	if n.Type == html.ElementNode && n.Data == "a" {
26		for _, a := range n.Attr {
27			if a.Key == "href" {
28				links = append(links, a.Val)
29			}
30		}
31	}
32	for c := n.FirstChild; c != nil; c = c.NextSibling {
33		links = visit(links, c)
34	}
35	return links
36}
37
38func outline(stack []string, n *html.Node) {
39	if n.Type == html.ElementNode {
40		stack = append(stack, n.Data)
41		fmt.Println(stack)
42	}
43	for c := n.FirstChild; c != nil; c = c.NextSibling {
44		outline(stack, c)
45	}
46}

findlinks2.go

 1package main
 2
 3import (
 4	"fmt"
 5	"net/http"
 6	"os"
 7
 8	"golang.org/x/net/html"
 9)
10
11func main() {
12	for _, url := range os.Args[1:] {
13		links, err := findLinks(url)
14		if err != nil {
15			fmt.Fprintf(os.Stderr, "findlinks2: %v\n", err)
16			continue
17		}
18		for _, link := range links {
19			fmt.Println(link)
20		}
21	}
22}
23
24func visit(links []string, n *html.Node) []string {
25	if n.Type == html.ElementNode && n.Data == "a" {
26		for _, a := range n.Attr {
27			if a.Key == "href" {
28				links = append(links, a.Val)
29			}
30		}
31	}
32	for c := n.FirstChild; c != nil; c = c.NextSibling {
33		links = visit(links, c)
34	}
35	return links
36}
37
38func findLinks(url string) ([]string, error) {
39	resp, err := http.Get(url)
40	if err != nil {
41		return nil, err
42	}
43	if resp.StatusCode != http.StatusOK {
44		resp.Body.Close()
45		return nil, fmt.Errorf("getting %s: %s", url, resp.Status)
46	}
47	doc, err := html.Parse(resp.Body)
48	resp.Body.Close()
49	if err != nil {
50		return nil, fmt.Errorf("parsing %s as HTML: %v", url, err)
51	}
52	return visit(nil, doc), nil
53}

匿名函数

topsort.go

用深度优先搜索遍历整张先修关系图,得到一个满足先修要求的课程序列

 1package main
 2
 3import (
 4	"fmt"
 5	"sort"
 6)
 7
// perreqs maps each course to the courses that must be taken before it.
// NOTE(review): the name looks like a misspelling of "prereqs"; it is left
// unchanged because main refers to it.
var perreqs = map[string][]string{
	"algorithms": {"data structures"},
	"calculus":   {"linear algebra"},
	"compilers": {
		"data structures",
		"formal languages",
		"computer organization",
	},
	"data structures":       {"discrete math"},
	"databases":             {"data structures"},
	"discrete math":         {"intro to programming"},
	"formal languages":      {"discrete math"},
	"networks":              {"operating systems"},
	"operating systems":     {"data structures", "computer organization"},
	"programming languages": {"data structures", "computer organization"},
}
24
25func main() {
26	for i, course := range topoSort(perreqs) {
27		fmt.Printf("%d:\t%s\n", i+1, course)
28	}
29}
30
// topoSort returns the items of the dependency graph m (item -> items it
// depends on) in an order where each item follows its dependencies. It
// runs a depth-first search from the sorted keys of m, so the result is
// deterministic. Cycles are not detected.
func topoSort(m map[string][]string) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var order []string
	visited := make(map[string]bool)
	var dfs func(nodes []string)
	dfs = func(nodes []string) {
		for _, node := range nodes {
			if visited[node] {
				continue
			}
			visited[node] = true
			dfs(m[node])               // emit dependencies first...
			order = append(order, node) // ...then the node itself
		}
	}
	dfs(keys)
	return order
}

findlinks3.go

广度优先搜索

 1package main
 2
 3import (
 4	"fmt"
 5	"log"
 6	"net/http"
 7	"os"
 8
 9	"golang.org/x/net/html"
10)
11
12func main() {
13	breadthFirst(crawl, os.Args[1:])
14}
15
// breadthFirst calls f once for each distinct item reachable from
// worklist, one level at a time. The items returned by f for the current
// level form the next level. Items already seen are skipped, so each item
// is processed at most once.
func breadthFirst(f func(item string) []string, worklist []string) {
	seen := make(map[string]bool)
	for current := worklist; len(current) > 0; {
		var next []string
		for _, item := range current {
			if seen[item] {
				continue
			}
			seen[item] = true
			next = append(next, f(item)...)
		}
		current = next
	}
}
29
30func crawl(url string) []string {
31	fmt.Println(url)
32	list, err := Extract(url)
33	if err != nil {
34		log.Print(err)
35	}
36	return list
37}
38
39func Extract(url string) ([]string, error) {
40	resp, err := http.Get(url)
41	if err != nil {
42		return nil, err
43	}
44	if resp.StatusCode != http.StatusOK {
45		resp.Body.Close()
46		return nil, fmt.Errorf("getting %s: %s", url, resp.Status)
47	}
48	doc, err := html.Parse(resp.Body)
49	resp.Body.Close()
50	if err != nil {
51		return nil, fmt.Errorf("parsing %s as HTML: %v", url, err)
52	}
53	var links []string
54	visitNode := func(n *html.Node) {
55		if n.Type == html.ElementNode && n.Data == "a" {
56			for _, a := range n.Attr {
57				if a.Key != "href" {
58					continue
59				}
60				link, err := resp.Request.URL.Parse(a.Val)
61				if err != nil {
62					continue
63				}
64				links = append(links, link.String())
65			}
66		}
67	}
68	forEachNode(doc, visitNode, nil)
69	return links, nil
70}
71
72func forEachNode(n *html.Node, pre, post func(n *html.Node)) {
73	if pre != nil {
74		pre(n)
75	}
76	for c := n.FirstChild; c != nil; c = c.NextSibling {
77		forEachNode(c, pre, post)
78	}
79	if post != nil {
80		post(n)
81	}
82}

defer

两种写法对比

title1.go

 1package main
 2
 3import (
 4	"fmt"
 5	"log"
 6	"net/http"
 7	"os"
 8	"strings"
 9
10	"golang.org/x/net/html"
11)
12
13func main() {
14	for _, val := range os.Args[1:] {
15		if err := title(val); err != nil {
16			log.Print(err)
17		}
18	}
19}
20
21func title(url string) error {
22	resp, err := http.Get(url)
23	if err != nil {
24		return err
25	}
26
27	ct := resp.Header.Get("Content-Type")
28	if ct != "text/html" && !strings.HasPrefix(ct, "text/html;") {
29		resp.Body.Close()
30		return fmt.Errorf("%s has type %s, not text/html", url, ct)
31	}
32
33	doc, err := html.Parse(resp.Body)
34	resp.Body.Close()
35	if err != nil {
36		return fmt.Errorf("parsing %s as HTML: %v", url, err)
37	}
38	visitNode := func(n *html.Node) {
39		if n.Type == html.ElementNode && n.Data == "title" && n.FirstChild != nil {
40			fmt.Println(n.FirstChild.Data)
41		}
42	}
43	forEachNode(doc, visitNode, nil)
44	return nil
45}
46
47func forEachNode(n *html.Node, pre, post func(n *html.Node)) {
48	if pre != nil {
49		pre(n)
50	}
51	for c := n.FirstChild; c != nil; c = c.NextSibling {
52		forEachNode(c, pre, post)
53	}
54	if post != nil {
55		post(n)
56	}
57}

title2.go

 1package main
 2
 3import (
 4	"fmt"
 5	"log"
 6	"net/http"
 7	"os"
 8	"strings"
 9
10	"golang.org/x/net/html"
11)
12
13func main() {
14	for _, val := range os.Args[1:] {
15		if err := title(val); err != nil {
16			log.Print(err)
17		}
18	}
19}
20
21func title(url string) error {
22	resp, err := http.Get(url)
23	if err != nil {
24		return err
25	}
26
27	defer resp.Body.Close()
28
29	ct := resp.Header.Get("Content-Type")
30	if ct != "text/html" && !strings.HasPrefix(ct, "text/html;") {
31		return fmt.Errorf("%s has type %s, not text/html", url, ct)
32	}
33
34	doc, err := html.Parse(resp.Body)
35	if err != nil {
36		return fmt.Errorf("parsing %s as HTML: %v", url, err)
37	}
38	visitNode := func(n *html.Node) {
39		if n.Type == html.ElementNode && n.Data == "title" && n.FirstChild != nil {
40			fmt.Println(n.FirstChild.Data)
41		}
42	}
43	forEachNode(doc, visitNode, nil)
44	return nil
45}
46
47func forEachNode(n *html.Node, pre, post func(n *html.Node)) {
48	if pre != nil {
49		pre(n)
50	}
51	for c := n.FirstChild; c != nil; c = c.NextSibling {
52		forEachNode(c, pre, post)
53	}
54	if post != nil {
55		post(n)
56	}
57}

接口

bytecounter.go

实现 Write 方法,使 *ByteCounter 满足 io.Writer 接口

 1package main
 2
 3import "fmt"
 4
 5type ByteCounter int
 6
 7func (c *ByteCounter) Write(p []byte) (int, error) {
 8	*c += ByteCounter(len(p))
 9	return len(p), nil
10}
11
12func main() {
13	var c ByteCounter
14	c.Write([]byte("hello"))
15	fmt.Println(c)
16	c = 0
17	var name = "Dolly"
18	fmt.Fprintf(&c, "hello, %s", name)
19	fmt.Println(c)
20}

flag.Value 接口

sleep.go

 1package main
 2
 3import (
 4	"flag"
 5	"fmt"
 6	"time"
 7)
 8
// period points at the value of the -period duration flag (default one
// second); it is valid after flag.Parse.
var period = flag.Duration("period", 1*time.Second, "sleep period")
10
11func main() {
12	flag.Parse()
13	fmt.Printf("Sleeping for %v", *period)
14	time.Sleep(*period)
15	fmt.Println()
16}

sort.Interface 接口

sorting.go

 1package main
 2
 3import (
 4	"fmt"
 5	"os"
 6	"sort"
 7	"text/tabwriter"
 8	"time"
 9)
10
11func main() {
12	printTracks(tracks)
13
14	sort.Sort(byArtist(tracks))
15	printTracks(tracks)
16
17	sort.Sort(sort.Reverse(byArtist(tracks)))
18	printTracks(tracks)
19
20	sort.Sort(byYear(tracks))
21	printTracks(tracks)
22
23	sort.Sort(customSort{tracks, func(x, y *Track) bool {
24		if x.Title != y.Title {
25			return x.Title < y.Title
26		}
27		if x.Year != y.Year {
28			return x.Year < y.Year
29		}
30		if x.Length != y.Length {
31			return x.Length < y.Length
32		}
33		return false
34	}})
35	printTracks(tracks)
36
37	fmt.Println(sort.IsSorted(byYear(tracks)))
38}
39
40type byArtist []*Track
41
42func (x byArtist) Len() int           { return len(x) }
43func (x byArtist) Less(i, j int) bool { return x[i].Artist < x[j].Artist }
44func (x byArtist) Swap(i, j int)      { x[i], x[j] = x[j], x[i] }
45
46type byYear []*Track
47
48func (x byYear) Len() int           { return len(x) }
49func (x byYear) Less(i, j int) bool { return x[i].Year < x[j].Year }
50func (x byYear) Swap(i, j int)      { x[i], x[j] = x[j], x[i] }
51
52type customSort struct {
53	t    []*Track
54	less func(x, y *Track) bool
55}
56
57func (x customSort) Len() int           { return len(x.t) }
58func (x customSort) Less(i, j int) bool { return x.less(x.t[i], x.t[j]) }
59func (x customSort) Swap(i, j int)      { x.t[i], x.t[j] = x.t[j], x.t[i] }
60
// Track describes one song in the sample playlist: title, performing
// artist, album, release year and running time.
type Track struct {
	Title  string
	Artist string
	Album  string
	Year   int
	Length time.Duration
}
68
69var tracks = []*Track{
70	{"Go", "Delilah", "From the Roots Up", 2012, length("3m38s")},
71	{"Go", "Moby", "Moby", 1992, length("3m37s")},
72	{"Go Ahead", "Alicia Keys", "As I Am", 2007, length("4m36s")},
73	{"Ready 2 Go", "Martin Solveig", "Smash", 2011, length("4m24s")},
74}
75
// length parses s with time.ParseDuration (e.g. "3m38s"). It is meant for
// the constant table above, so an invalid string is a programming error
// and panics. The panic message now includes the parse error; previously
// it carried only the input, with no cause.
func length(s string) time.Duration {
	d, err := time.ParseDuration(s)
	if err != nil {
		panic(fmt.Sprintf("invalid track length %q: %v", s, err))
	}
	return d
}
83
84func printTracks(track []*Track) {
85	const format = "%v\t%v\t%v\t%v\t%v\t\n"
86	tw := new(tabwriter.Writer).Init(os.Stdout, 0, 8, 2, ' ', 0)
87	fmt.Fprintf(tw, format, "Title", "Artist", "Album", "Year", "Length")
88	fmt.Fprintf(tw, format, "-----", "------", "-----", "----", "------")
89	for _, t := range track {
90		fmt.Fprintf(tw, format, t.Title, t.Artist, t.Album, t.Year, t.Length)
91	}
92	tw.Flush()
93	fmt.Println()
94}

Goroutines

spinner.go

 1package main
 2
 3import (
 4	"fmt"
 5	"time"
 6)
 7
 8func fib(x int) int {
 9	if x < 2 {
10		return x
11	}
12	return fib(x-1) + fib(x-2)
13}
14
// spinner animates a text spinner on stdout forever, redrawing in place
// with '\r' and advancing one frame every delay. It never returns.
func spinner(delay time.Duration) {
	frames := `-\|/`
	for i := 0; ; i = (i + 1) % len(frames) {
		fmt.Printf("\r%c", frames[i])
		time.Sleep(delay)
	}
}
23
24func main() {
25	go spinner(100 * time.Millisecond)
26	const n = 45
27	fibN := fib(n)
28	fmt.Printf("\rFibonacci(%d) = %d\n", n, fibN)
29}

clock1.go

 1package main
 2
 3import (
 4	"io"
 5	"log"
 6	"net"
 7	"time"
 8)
 9
10func main() {
11	listener, err := net.Listen("tcp", "localhost:8000")
12	if err != nil {
13		log.Fatal(err)
14	}
15
16	for {
17		conn, err := listener.Accept()
18		if err != nil {
19			log.Print(err)
20			continue
21		}
22		go handleConn(conn)
23	}
24}
25
// handleConn writes the current time as "HH:MM:SS\n" to c once per second
// until a write fails, then closes c.
func handleConn(c net.Conn) {
	defer c.Close()
	const layout = "15:04:05\n"
	for {
		stamp := time.Now().Format(layout)
		if _, err := io.WriteString(c, stamp); err != nil {
			return // the write failed, e.g. because the peer went away
		}
		time.Sleep(1 * time.Second)
	}
}

netcat1.go

 1package main
 2
 3import (
 4	"io"
 5	"log"
 6	"net"
 7	"os"
 8)
 9
10func main() {
11	conn, err := net.Dial("tcp", "localhost:8000")
12	if err != nil {
13		log.Fatal(err)
14	}
15	defer conn.Close()
16	go mustCopy(os.Stdout, conn)
17	mustCopy(conn, os.Stdin)
18}
19
// mustCopy copies src to dst until EOF; any copy error is fatal.
func mustCopy(dst io.Writer, src io.Reader) {
	_, err := io.Copy(dst, src)
	if err != nil {
		log.Fatal(err)
	}
}

Channels

netcat3.go

 1package main
 2
 3import (
 4	"io"
 5	"log"
 6	"net"
 7	"os"
 8)
 9
10func main() {
11	conn, err := net.Dial("tcp", "localhost:8000")
12	if err != nil {
13		log.Fatal(err)
14	}
15	done := make(chan struct{})
16	go func() {
17		io.Copy(os.Stdout, conn)
18		log.Println("done")
19		done <- struct{}{}
20	}()
21	mustCopy(conn, os.Stdin)
22	conn.Close()
23	<-done
24}
25
// mustCopy copies src to dst until EOF; any copy error is fatal.
func mustCopy(dst io.Writer, src io.Reader) {
	_, err := io.Copy(dst, src)
	if err != nil {
		log.Fatal(err)
	}
}

pipeline1.go

 1package main
 2
 3import (
 4	"fmt"
 5	"time"
 6)
 7
 8func main() {
 9	naturals := make(chan int)
10	squares := make(chan int)
11
12	// Counter
13	go func() {
14		for x := 0; ; x++ {
15			naturals <- x
16			time.Sleep(time.Millisecond * 500)
17		}
18	}()
19
20	// Squarer
21	go func() {
22		for {
23			x := <-naturals
24			squares <- x * x
25		}
26	}()
27
28	// Printer (in main goroutine)
29	for {
30		fmt.Println(<-squares)
31	}
32}

pipeline2.go

 1package main
 2
 3import (
 4	"fmt"
 5	"time"
 6)
 7
 8func main() {
 9	naturals := make(chan int)
10	squares := make(chan int)
11
12	// Counter
13	go func() {
14		for x := 0; x < 100; x++ {
15			naturals <- x
16			time.Sleep(time.Millisecond * 100)
17		}
18		close(naturals)
19	}()
20
21	// Squarer
22	go func() {
23		for x := range naturals {
24			squares <- x * x
25		}
26		close(squares)
27	}()
28
29	// Printer (in main goroutine)
30	for x := range squares {
31		fmt.Println(x)
32	}
33}

pipeline3.go

单方向 channel

 1package main
 2
 3import (
 4	"fmt"
 5)
 6
 7func counter(out chan<- int) {
 8	for x := 0; x < 100; x++ {
 9		out <- x
10	}
11	close(out)
12}
13
// squarer sends the square of each value received from in on out. When in
// is closed, it closes out.
func squarer(out chan<- int, in <-chan int) {
	defer close(out)
	for {
		v, ok := <-in
		if !ok {
			return
		}
		out <- v * v
	}
}
20
// printer prints each value received from in on its own line until in is
// closed.
func printer(in <-chan int) {
	for {
		v, ok := <-in
		if !ok {
			break
		}
		fmt.Println(v)
	}
}
26
27func main() {
28	naturals := make(chan int)
29	squares := make(chan int)
30
31	// Counter
32	go counter(naturals)
33	// Squarer
34	go squarer(squares, naturals)
35	// Printer (in main goroutine)
36	printer(squares)
37}

并发爬虫

crawl1.go

无限制并发

 1package main
 2
 3import (
 4	"fmt"
 5	"log"
 6	"net/http"
 7	"os"
 8
 9	"golang.org/x/net/html"
10)
11
12func main() {
13	worklist := make(chan []string)
14	go func() { worklist <- os.Args[1:] }()
15	seen := make(map[string]bool)
16	for list := range worklist {
17		for _, link := range list {
18			if !seen[link] {
19				seen[link] = true
20				go func(link string) {
21					worklist <- crawl(link)
22				}(link)
23			}
24		}
25	}
26}
27
28func crawl(url string) []string {
29	fmt.Println(url)
30	list, err := Extract(url)
31	if err != nil {
32		log.Print(err)
33	}
34	return list
35}
36
37func Extract(url string) ([]string, error) {
38	resp, err := http.Get(url)
39	if err != nil {
40		return nil, err
41	}
42	if resp.StatusCode != http.StatusOK {
43		resp.Body.Close()
44		return nil, fmt.Errorf("getting %s: %s", url, resp.Status)
45	}
46	doc, err := html.Parse(resp.Body)
47	resp.Body.Close()
48	if err != nil {
49		return nil, fmt.Errorf("parsing %s as HTML: %v", url, err)
50	}
51	var links []string
52	visitNode := func(n *html.Node) {
53		if n.Type == html.ElementNode && n.Data == "a" {
54			for _, a := range n.Attr {
55				if a.Key != "href" {
56					continue
57				}
58				link, err := resp.Request.URL.Parse(a.Val)
59				if err != nil {
60					continue
61				}
62				links = append(links, link.String())
63			}
64		}
65	}
66	forEachNode(doc, visitNode, nil)
67	return links, nil
68}
69
70func forEachNode(n *html.Node, pre, post func(n *html.Node)) {
71	if pre != nil {
72		pre(n)
73	}
74	for c := n.FirstChild; c != nil; c = c.NextSibling {
75		forEachNode(c, pre, post)
76	}
77	if post != nil {
78		post(n)
79	}
80}

crawl2.go

将对 Extract 的调用用获取、释放 token 的操作包裹起来,以确保同一时间最多只有 20 个 Extract 调用在进行

 1package main
 2
 3import (
 4	"fmt"
 5	"log"
 6	"net/http"
 7	"os"
 8
 9	"golang.org/x/net/html"
10)
11
12func main() {
13	worklist := make(chan []string)
14	// number of pending sends to worklist
15	var n int
16
17	// start with the command-line arguments
18	n++
19	go func() { worklist <- os.Args[1:] }()
20
21	// crawl the web concurrently
22	seen := make(map[string]bool)
23
24	for ; n > 0; n-- {
25		list := <-worklist
26		for _, link := range list {
27			if !seen[link] {
28				seen[link] = true
29				n++
30				go func(link string) {
31					worklist <- crawl(link)
32				}(link)
33			}
34		}
35	}
36}
37
// tokens is a counting semaphore with capacity 20. crawl sends a value to
// acquire a slot before calling Extract and receives one to release it, so
// at most 20 Extract calls run at the same time.
var tokens = make(chan struct{}, 20)
39
40func crawl(url string) []string {
41	fmt.Println(url)
42	// acquire a token
43	tokens <- struct{}{}
44	list, err := Extract(url)
45	// release the token
46	<-tokens
47	if err != nil {
48		log.Print(err)
49	}
50	return list
51}
52
53func Extract(url string) ([]string, error) {
54	resp, err := http.Get(url)
55	if err != nil {
56		return nil, err
57	}
58	if resp.StatusCode != http.StatusOK {
59		resp.Body.Close()
60		return nil, fmt.Errorf("getting %s: %s", url, resp.Status)
61	}
62	doc, err := html.Parse(resp.Body)
63	resp.Body.Close()
64	if err != nil {
65		return nil, fmt.Errorf("parsing %s as HTML: %v", url, err)
66	}
67	var links []string
68	visitNode := func(n *html.Node) {
69		if n.Type == html.ElementNode && n.Data == "a" {
70			for _, a := range n.Attr {
71				if a.Key != "href" {
72					continue
73				}
74				link, err := resp.Request.URL.Parse(a.Val)
75				if err != nil {
76					continue
77				}
78				links = append(links, link.String())
79			}
80		}
81	}
82	forEachNode(doc, visitNode, nil)
83	return links, nil
84}
85
86func forEachNode(n *html.Node, pre, post func(n *html.Node)) {
87	if pre != nil {
88		pre(n)
89	}
90	for c := n.FirstChild; c != nil; c = c.NextSibling {
91		forEachNode(c, pre, post)
92	}
93	if post != nil {
94		post(n)
95	}
96}

基于 select 的多路复用

countdown1.go

 1package main
 2
 3import (
 4	"fmt"
 5	"time"
 6)
 7
 8func main() {
 9	fmt.Println("Commencing countdown.")
10	tick := time.Tick(1 * time.Second)
11	for countdown := 10; countdown > 0; countdown-- {
12		fmt.Println(countdown)
13		<-tick
14	}
15	launch()
16}
17
// launch announces the launch on standard output.
func launch() {
	const msg = "launch"
	fmt.Println(msg)
}

countdown2.go

 1package main
 2
 3import (
 4	"fmt"
 5	"os"
 6	"time"
 7)
 8
 9func main() {
10	abort := make(chan struct{})
11	go func() {
12		os.Stdin.Read(make([]byte, 1))
13		abort <- struct{}{}
14	}()
15
16	fmt.Println("Commencing countdown. Press return to abort.")
17	select {
18	case <-time.After(10 * time.Second):
19	case <-abort:
20		fmt.Println("Launch aborted!")
21		return
22	}
23	launch()
24}
25
// launch announces the launch on standard output.
func launch() {
	const msg = "launch"
	fmt.Println(msg)
}

countdown3.go

 1package main
 2
 3import (
 4	"fmt"
 5	"os"
 6	"time"
 7)
 8
 9func main() {
10	abort := make(chan struct{})
11	go func() {
12		os.Stdin.Read(make([]byte, 1))
13		abort <- struct{}{}
14	}()
15
16	fmt.Println("Commencing countdown. Press return to abort.")
17	tick := time.Tick(time.Second)
18	for countdown := 10; countdown > 0; countdown-- {
19		fmt.Println(countdown)
20		select {
21		case <-tick:
22		case <-abort:
23			fmt.Println("Launch aborted!")
24			return
25		}
26	}
27	launch()
28}
29
// launch announces the launch on standard output.
func launch() {
	const msg = "launch"
	fmt.Println(msg)
}

示例: 并发的目录遍历

du1.go

 1package main
 2
 3import (
 4	"flag"
 5	"fmt"
 6	"io/ioutil"
 7	"os"
 8	"path/filepath"
 9)
10
11func main() {
12	// Determine the initial directories
13	flag.Parse()
14	roots := flag.Args()
15	if len(roots) == 0 {
16		roots = []string{"."}
17	}
18
19	// Traverse the file tree
20	fileSizes := make(chan int64)
21	go func() {
22		for _, root := range roots {
23			walkDir(root, fileSizes)
24		}
25		close(fileSizes)
26	}()
27
28	// Print the results
29	var nfiles, nbytes int64
30	for size := range fileSizes {
31		nfiles++
32		nbytes += size
33	}
34	printDiskUsage(nfiles, nbytes)
35}
36
// printDiskUsage prints the file count and the total size in gigabytes,
// where a gigabyte is 1e9 bytes.
func printDiskUsage(nfiles, nbytes int64) {
	gb := float64(nbytes) / 1e9
	fmt.Printf("%d files %.1f GB\n", nfiles, gb)
}
40
41func walkDir(dir string, fileSizes chan<- int64) {
42	for _, entry := range dirents(dir) {
43		if entry.IsDir() {
44			subdir := filepath.Join(dir, entry.Name())
45			walkDir(subdir, fileSizes)
46		} else {
47			fileSizes <- entry.Size()
48		}
49	}
50}
51
// dirents returns the entries of directory dir. If dir cannot be read, the
// error is reported on stderr and nil is returned.
func dirents(dir string) []os.FileInfo {
	entries, err := ioutil.ReadDir(dir)
	if err == nil {
		return entries
	}
	fmt.Fprintf(os.Stderr, "du1: %v\n", err)
	return nil
}

du2.go

 1package main
 2
 3import (
 4	"flag"
 5	"fmt"
 6	"io/ioutil"
 7	"os"
 8	"path/filepath"
 9	"time"
10)
11
// verbose is set by -v. When true, main prints running totals every 500ms
// while the walk is in progress.
var verbose = flag.Bool("v", false, "show verbose progress messages")
13
14func main() {
15	// Determine the initial directories
16	flag.Parse()
17	roots := flag.Args()
18	if len(roots) == 0 {
19		roots = []string{"."}
20	}
21
22	// Traverse the file tree
23	fileSizes := make(chan int64)
24	go func() {
25		for _, root := range roots {
26			walkDir(root, fileSizes)
27		}
28		close(fileSizes)
29	}()
30
31	var tick <-chan time.Time
32	if *verbose {
33		tick = time.Tick(500 * time.Millisecond)
34	}
35
36	// Print the results
37	var nfiles, nbytes int64
38
39loop:
40	for {
41		select {
42		case size, ok := <-fileSizes:
43			if !ok {
44				break loop
45			}
46			nfiles++
47			nbytes += size
48		case <-tick:
49			printDiskUsage(nfiles, nbytes)
50		}
51	}
52	printDiskUsage(nfiles, nbytes)
53}
54
// printDiskUsage prints a one-line summary: the file count followed by the
// byte total expressed in gigabytes (1e9 bytes) with one decimal place.
func printDiskUsage(nfiles, nbytes int64) {
	gigabytes := float64(nbytes) / 1e9
	fmt.Printf("%d files %.1f GB\n", nfiles, gigabytes)
}
58
59func walkDir(dir string, fileSizes chan<- int64) {
60	for _, entry := range dirents(dir) {
61		if entry.IsDir() {
62			subdir := filepath.Join(dir, entry.Name())
63			walkDir(subdir, fileSizes)
64		} else {
65			fileSizes <- entry.Size()
66		}
67	}
68}
69
// dirents returns the entries of directory dir. When the directory cannot
// be read, the error is reported on stderr (prefixed with this program's
// name, du2) and nil is returned, so the walk skips that directory.
func dirents(dir string) []os.FileInfo {
	entries, err := ioutil.ReadDir(dir)
	if err != nil {
		fmt.Fprintf(os.Stderr, "du2: %v\n", err)
		return nil
	}
	return entries
}

du3.go

 1package main
 2
 3import (
 4	"flag"
 5	"fmt"
 6	"io/ioutil"
 7	"os"
 8	"path/filepath"
 9	"sync"
10	"time"
11)
12
13var verbose = flag.Bool("v", false, "show verbose progress messages")
14
15func main() {
16	// Determine the initial directories
17	flag.Parse()
18	roots := flag.Args()
19	if len(roots) == 0 {
20		roots = []string{"."}
21	}
22
23	// Traverse each root of the file tree in parallel
24	fileSizes := make(chan int64)
25	var n sync.WaitGroup
26	for _, root := range roots {
27		n.Add(1)
28		go walkDir(root, &n, fileSizes)
29	}
30	go func() {
31		n.Wait()
32		close(fileSizes)
33	}()
34
35	var tick <-chan time.Time
36	if *verbose {
37		tick = time.Tick(500 * time.Millisecond)
38	}
39
40	// Print the results
41	var nfiles, nbytes int64
42
43loop:
44	for {
45		select {
46		case size, ok := <-fileSizes:
47			if !ok {
48				break loop
49			}
50			nfiles++
51			nbytes += size
52		case <-tick:
53			printDiskUsage(nfiles, nbytes)
54		}
55	}
56	printDiskUsage(nfiles, nbytes)
57}
58
// printDiskUsage prints a one-line summary: the file count followed by the
// byte total expressed in gigabytes (1e9 bytes) with one decimal place.
func printDiskUsage(nfiles, nbytes int64) {
	gigabytes := float64(nbytes) / 1e9
	fmt.Printf("%d files %.1f GB\n", nfiles, gigabytes)
}
62
63func walkDir(dir string, n *sync.WaitGroup, fileSizes chan<- int64) {
64	defer n.Done()
65	for _, entry := range dirents(dir) {
66		if entry.IsDir() {
67			n.Add(1)
68			subdir := filepath.Join(dir, entry.Name())
69			go walkDir(subdir, n, fileSizes)
70		} else {
71			fileSizes <- entry.Size()
72		}
73	}
74}
75
// sema is a counting semaphore that limits dirents to 20 concurrent
// directory reads.
var sema = make(chan struct{}, 20)

// dirents returns the entries of directory dir, holding a sema token for
// the duration of the read. When the directory cannot be read, the error is
// reported on stderr (prefixed with this program's name, du3) and nil is
// returned, so the walk skips that directory.
func dirents(dir string) []os.FileInfo {
	sema <- struct{}{}        // acquire token
	defer func() { <-sema }() // release token
	entries, err := ioutil.ReadDir(dir)
	if err != nil {
		fmt.Fprintf(os.Stderr, "du3: %v\n", err)
		return nil
	}
	return entries
}

并发的退出

du4.go

  1package main
  2
  3import (
  4	"flag"
  5	"fmt"
  6	"io/ioutil"
  7	"os"
  8	"path/filepath"
  9	"sync"
 10	"time"
 11)
 12
 13var verbose = flag.Bool("v", false, "show verbose progress messages")
 14
 15var done = make(chan struct{})
 16
 17func cancelled() bool {
 18	select {
 19	case <-done:
 20		return true
 21	default:
 22		return false
 23	}
 24}
 25
 26func main() {
 27	// Determine the initial directories
 28	flag.Parse()
 29	roots := flag.Args()
 30	if len(roots) == 0 {
 31		roots = []string{"."}
 32	}
 33
 34	// Cancel traversal when input is detected
 35	go func() {
 36		os.Stdin.Read(make([]byte, 1))
 37		close(done)
 38	}()
 39
 40	// Traverse each root of the file tree in parallel
 41	fileSizes := make(chan int64)
 42	var n sync.WaitGroup
 43	for _, root := range roots {
 44		n.Add(1)
 45		go walkDir(root, &n, fileSizes)
 46	}
 47	go func() {
 48		n.Wait()
 49		close(fileSizes)
 50	}()
 51
 52	var tick <-chan time.Time
 53	if *verbose {
 54		tick = time.Tick(500 * time.Millisecond)
 55	}
 56
 57	// Print the results
 58	var nfiles, nbytes int64
 59
 60loop:
 61	for {
 62		select {
 63		case <-done:
 64			// Drain fileSizes to allow existing goroutines to finish
 65			for range fileSizes {
 66				// Do nothing
 67			}
 68			return
 69		case size, ok := <-fileSizes:
 70			if !ok {
 71				break loop
 72			}
 73			nfiles++
 74			nbytes += size
 75		case <-tick:
 76			printDiskUsage(nfiles, nbytes)
 77		}
 78	}
 79	printDiskUsage(nfiles, nbytes)
 80}
 81
 82func printDiskUsage(nfiles, nbytes int64) {
 83	fmt.Printf("%d files %.1f GB\n", nfiles, float64(nbytes)/1e9)
 84}
 85
 86func walkDir(dir string, n *sync.WaitGroup, fileSizes chan<- int64) {
 87	defer n.Done()
 88	if cancelled() {
 89		return
 90	}
 91	for _, entry := range dirents(dir) {
 92		if entry.IsDir() {
 93			n.Add(1)
 94			subdir := filepath.Join(dir, entry.Name())
 95			go walkDir(subdir, n, fileSizes)
 96		} else {
 97			fileSizes <- entry.Size()
 98		}
 99	}
100}
101
102// sema is a counting semaphore for limiting concurrency in dirents
103var sema = make(chan struct{}, 20)
104
105func dirents(dir string) []os.FileInfo {
106	select {
107	case sema <- struct{}{}: // acquire token
108	case <-done:
109		return nil // cancelled
110	}
111	defer func() { <-sema }() // release token
112	entries, err := ioutil.ReadDir(dir)
113	if err != nil {
114		fmt.Fprintf(os.Stderr, "du1: %v\n", err)
115		return nil
116	}
117	return entries
118}

示例:聊天服务

chat.go

 1package main
 2
 3import (
 4	"bufio"
 5	"fmt"
 6	"log"
 7	"net"
 8)
 9
10func main() {
11	listener, err := net.Listen("tcp", "localhost:8000")
12	if err != nil {
13		log.Fatal(err)
14	}
15
16	go broadcaster()
17
18	for {
19		conn, err := listener.Accept()
20		if err != nil {
21			log.Print(err)
22			continue
23		}
24		go handleConn(conn)
25	}
26}
27
// client is the send-only channel on which a connected client receives
// its outgoing messages.
type client chan<- string

var (
	entering = make(chan client) // clients joining the chat
	leaving  = make(chan client) // clients leaving; broadcaster closes their channel
	messages = make(chan string) // all incoming client messages
)

// broadcaster owns the set of connected clients. It registers clients
// received on entering, unregisters and closes clients received on leaving,
// and copies every message received on messages to each registered client.
// NOTE: each send blocks until that client's channel accepts the message,
// so a client that stops receiving stalls delivery to everyone.
func broadcaster() {
	members := make(map[client]bool)
	for {
		select {
		case cli := <-entering:
			members[cli] = true
		case cli := <-leaving:
			delete(members, cli)
			close(cli)
		case msg := <-messages:
			for member := range members {
				member <- msg
			}
		}
	}
}
56
57func handleConn(conn net.Conn) {
58	ch := make(chan string) //outgoing client messages
59	go clientWriter(conn, ch)
60
61	who := conn.RemoteAddr().String()
62	ch <- "You are" + who
63	messages <- who + " has arrived"
64	entering <- ch
65
66	input := bufio.NewScanner(conn)
67	for input.Scan() {
68		messages <- who + ": " + input.Text()
69	}
70	// NOTE: ignoring potential errors from input.Err()
71
72	leaving <- ch
73	messages <- who + " has left"
74	conn.Close()
75}
76
// clientWriter writes each message received on ch to conn as its own line,
// returning once ch has been closed.
func clientWriter(conn net.Conn, ch <-chan string) {
	for {
		msg, ok := <-ch
		if !ok {
			return
		}
		fmt.Fprintln(conn, msg) // NOTE: ignoring network errors
	}
}

竞争条件

bank1.go

balance 变量被限制在了 monitor goroutine 中

 1package bank
 2
 3var deposits = make(chan int) // send amout to deposit
 4var balances = make(chan int) // receive balance
 5
 6func Deposit(amount int) { deposits <- amount }
 7func Balance() int       { return <-balances }
 8
 9func teller() {
10	var balance int // balance is confined to teller goroutine
11	for {
12		select {
13		case amount := <-deposits:
14			balance += amount
15		case balances <- balance:
16		}
17	}
18}
19
20func init() {
21	go teller() // start the monitor goroutine
22}

bank2.go

我们可以用一个容量只有 1 的 channel 来保证最多只有一个 goroutine 在同一时刻访问一个共享变量。一个只能为 1 和 0 的信号量叫做二元信号量(binary semaphore)

 1package bank
 2
 3var (
 4	sema    = make(chan struct{}, 1) // a binary semaphore guarding balance
 5	balance int
 6)
 7
 8func Deposit(amount int) {
 9	sema <- struct{}{} // acquire token
10	balance = balance + amount
11	<-sema // release token
12}
13
14func Balance() int {
15	sema <- struct{}{} // acquire token
16	b := balance
17	<-sema // release token
18	return b
19}

bank3.go

sync.RWMutex 读写锁:写操作用 Lock 独占,读操作用 RLock 共享

 1package bank
 2
 3import "sync"
 4
 5var (
 6	mu      sync.RWMutex // guards balance
 7	balance int
 8)
 9
10func Deposit(amount int) {
11	mu.Lock()
12	defer mu.Unlock()
13	balance = balance + amount
14}
15
16func Balance() int {
17	mu.RLock() // readers lock
18	defer mu.RUnlock()
19	return balance
20}

示例: 并发的非阻塞缓存

memo1.go

 1package memo
 2
 3type Memo struct {
 4	f     Func
 5	cache map[string]result
 6}
 7
 8type Func func(key string) (interface{}, error)
 9
10type result struct {
11	value interface{}
12	err   error
13}
14
15func New(f Func) *Memo {
16	return &Memo{f: f, cache: make(map[string]result)}
17}
18
19// not concurrency-safe
20func (memo *Memo) Get(key string) (interface{}, error) {
21	res, ok := memo.cache[key]
22	if !ok {
23		res.value, res.err = memo.f(key)
24		mome.cache[key] = res
25	}
26	return res.value, res.err
27}

memo2.go

 1package memo
 2
 3import "sync"
 4
 5type Memo struct {
 6	f     Func
 7	mu    sync.Mutex // guards cache
 8	cache map[string]result
 9}
10
11type Func func(key string) (interface{}, error)
12
13type result struct {
14	value interface{}
15	err   error
16}
17
18func New(f Func) *Memo {
19	return &Memo{f: f, cache: make(map[string]result)}
20}
21
22// Get is concurrency-safe
23// 完全丧失了并发的性能优点
24// 每次对 f 的调用期间都会持有锁
25// Get 将本来可以并行运行的 I/O 操作串行化了
26func (memo *Memo) Get(key string) (interface{}, error) {
27	memo.mu.Lock()
28	res, ok := memo.cache[key]
29	if !ok {
30		res.value, res.err = memo.f(key)
31		mome.cache[key] = res
32	}
33	memo.mu.Unlock()
34	return res.value, res.err
35}

memo3.go

 1package memo
 2
 3import "sync"
 4
 5type Memo struct {
 6	f     Func
 7	mu    sync.Mutex // guards cache
 8	cache map[string]result
 9}
10
11type Func func(key string) (interface{}, error)
12
13type result struct {
14	value interface{}
15	err   error
16}
17
18func New(f Func) *Memo {
19	return &Memo{f: f, cache: make(map[string]result)}
20}
21
22// Get is concurrency-safe
23// 两次获取锁的中间阶段,其他 goroutine 可以随意使用 cache
24// 使性能得到提升
25// 问题:同一时刻请求相同 url 时会重复请求,存在多余的工作
26func (memo *Memo) Get(key string) (interface{}, error) {
27	memo.mu.Lock()
28	res, ok := memo.cache[key]
29	memo.mu.Unlock()
30	if !ok {
31		res.value, res.err = memo.f(key)
32
33		// Between the two critical sections, several goroutines
34		// may race to compute f(key) and update the map
35		memo.mu.Lock()
36		mome.cache[key] = res
37		memo.mu.Unlock()
38	}
39	return res.value, res.err
40}

memo4.go

实现使用了一个互斥量来保护多个 goroutine 调用 Get 时的共享 map 变量

 1package memo
 2
 3import "sync"
 4
 5type entry struct {
 6	res   result
 7	ready chan struct{} // closed when res is ready
 8}
 9
10type Memo struct {
11	f     Func
12	mu    sync.Mutex // guards cache
13	cache map[string]*entry
14}
15
16type Func func(key string) (interface{}, error)
17
18type result struct {
19	value interface{}
20	err   error
21}
22
23func New(f Func) *Memo {
24	return &Memo{f: f, cache: make(map[string]*entry)}
25}
26
27// Get is concurrency-safe
28func (memo *Memo) Get(key string) (interface{}, error) {
29	memo.mu.Lock()
30	e := memo.cache[key]
31	if e == nil {
32		// This is the first request for this key
33		// This goroutine becomes responsible for computing
34		// the value and broadcasting the ready condition
35		e = &entry{ready: make(chan struct{})}
36		memo.cache[key] = e
37		memo.mu.Unlock()
38
39		e.res.value, e.res.err = memo.f(key)
40
41		close(e.ready) // broadcast ready condition
42	} else {
43		// This is a repeat request for this key
44		memo.mu.Unlock()
45
46		<-e.ready // wait for ready condition
47	}
48	return e.res.value, e.res.err
49}

memo5.go

channel 通信的方式

 1package memo
 2
 3// A request is a message requesting that the Func be applied to key
 4type request struct {
 5	key      string
 6	response chan<- result // the client wants a single result
 7}
 8
 9type entry struct {
10	res   result
11	ready chan struct{} // closed when res is ready
12}
13
14type Memo struct {
15	requests chan request
16}
17
18type Func func(key string) (interface{}, error)
19
20type result struct {
21	value interface{}
22	err   error
23}
24
25// New returns a memoization of f
26// Clients must subsequently call Close
27func New(f Func) *Memo {
28	memo := &Memo{requests: make(chan request)}
29	go memo.server(f)
30	return memo
31}
32
33// Get is concurrency-safe
34func (memo *Memo) Get(key string) (interface{}, error) {
35	response := make(chan result)
36	memo.requests <- request{key, response}
37	res := <-response
38	return res.value, res.err
39}
40
41func (memo *Memo) Close() { close(memo.requests) }
42
43func (memo *Memo) server(f Func) {
44	cache := make(map[string]*entry)
45	for req := range mome.requests {
46		e := cache[req.key]
47		if e == nil {
48			// This is the first request for this key
49			e = &entry{ready: make(chan struct{})}
50			cache[req.key] = e
51			go e.call(f, req.key)
52		}
53		go e.deliver(req.response)
54	}
55}
56
57func (e *entry) call(f Func, key string) {
58	// Evaluate the function
59	e.res.value, e.res.err = f(key)
60	// Broadcast the ready condition
61	close(e.ready)
62}
63
64func (e *entry) deliver(response chan<- result) {
65	// Wait for the ready condition
66	<-e.ready
67	// Send the result to the client
68	response <- e.res
69}

测试

word.go

 1// Package word provides utilities for word games.
 2package word
 3
 4import "unicode"
 5
 6// IsPalindrome reports whether s reads the same forward and backward.
 7// Letter case is ignored, as are non-letters.
 8func IsPalindrome(s string) bool {
 9	// letters := make([]rune, 0)
10	letters := make([]rune, 0, len(s))
11	for _, r := range s {
12		if unicode.IsLetter(r) {
13			letters = append(letters, unicode.ToLower(r))
14		}
15	}
16	n := len(letters) / 2
17	for i := 0; i < n; i++ {
18		if letters[i] != letters[len(letters)-1-i] {
19			return false
20		}
21	}
22	return true
23}

word_test.go

 1package word
 2
 3import "testing"
 4
 5func TestIsPalindrome(t *testing.T) {
 6	var tests = []struct {
 7		input string
 8		want  bool
 9	}{
10		{"", true},
11		{"a", true},
12		{"aa", true},
13		{"ab", false},
14		{"kayak", true},
15		{"detartrated", true},
16		{"A man, a plan, a canal: Panama", true},
17		{"Evil I did dwell; lewd did I live.", true},
18		{"Able was I ere I saw Elba", true},
19		{"été", true},
20		{"Et se resservir, ivresse reste.", true},
21		{"palindrome", false}, // non-palindrome
22		{"desserts", false},   // semi-palindrome
23	}
24	for _, test := range tests {
25		if got := IsPalindrome(test.input); got != test.want {
26			t.Errorf("IsPalindrome(%q) = %v", test.input, got)
27		}
28	}
29}
30
31func BenchmarkIsPalindrome(b *testing.B) {
32	for i := 0; i < b.N; i++ {
33		IsPalindrome("A man, a plan, a canal: Panama")
34	}
35}