client端：通过唯一的命令行参数 host:port 指定要连接的服务端
package main
import (
"log"
"net/url"
"os"
"os/signal"
"time"
"github.com/gorilla/websocket"
)
// main is the WebSocket echo client.
//
// It expects exactly one argument, the server's host:port, and dials
// ws://<host:port>/echo. A background goroutine logs every message it
// receives. The main loop sends the current time as a text message once per
// second. On os.Interrupt (Ctrl-C) it sends a normal-closure close frame and
// waits up to one second for the reader goroutine to finish before exiting.
func main() {
	// Configure logging before the first message so every line, including
	// the usage error, is formatted the same way.
	log.SetFlags(0)

	if len(os.Args) != 2 {
		log.Println("invalid args: client host:port")
		// Exit non-zero so callers and scripts can detect the usage error.
		os.Exit(2)
	}
	addr := os.Args[1]

	interrupt := make(chan os.Signal, 1)
	signal.Notify(interrupt, os.Interrupt)

	u := url.URL{Scheme: "ws", Host: addr, Path: "/echo"}
	log.Printf("connecting to %s", u.String())
	c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if err != nil {
		log.Fatal("dial:", err)
	}
	defer c.Close()
	log.Println("dial finished")

	// done is closed when the reader goroutine exits, i.e. after the first
	// read error (including the connection being closed).
	done := make(chan struct{})
	go func() {
		defer close(done)
		for {
			_, message, err := c.ReadMessage()
			if err != nil {
				log.Println("read:", err)
				return
			}
			log.Printf("recv: %s", message)
		}
	}()

	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()

	// All writes happen on this goroutine; the goroutine above only reads.
	for {
		select {
		case <-done:
			return
		case t := <-ticker.C:
			log.Printf("send: %s", t.String())
			err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
			if err != nil {
				log.Println("write:", err)
				return
			}
		case <-interrupt:
			log.Println("interrupt")
			// Start a clean close handshake instead of dropping the socket.
			err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
			if err != nil {
				log.Println("write close:", err)
				return
			}
			// Give the server up to one second to answer and close the
			// connection, which ends the reader goroutine.
			select {
			case <-done:
			case <-time.After(time.Second):
			}
			return
		}
	}
}
server端：通过 -addr 参数指定监听地址（默认 0.0.0.0:8077）
package main
import (
"context"
"flag"
"html/template"
"log"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/hertz-contrib/websocket"
)
// defaultAddr is the listen address used when -addr is not given.
const defaultAddr = "0.0.0.0:8077"

// addr holds the value of the -addr command-line flag: the host:port the
// server listens on.
var addr = flag.String("addr", defaultAddr, "http service address")
// upgrader converts HTTP requests on /echo into WebSocket connections.
//
// NOTE(review): CheckOrigin is allowAnyOrigin, so the Origin header is never
// rejected and pages served from any site may open a socket to this server.
// Acceptable for a demo; revisit before exposing it publicly.
var upgrader = websocket.HertzUpgrader{CheckOrigin: allowAnyOrigin}

// allowAnyOrigin accepts every WebSocket handshake regardless of its origin.
func allowAnyOrigin(_ *app.RequestContext) bool { return true }
// echo handles GET /echo. It upgrades the request to a WebSocket connection
// and writes every received message straight back to the peer with the same
// message type. The loop stops at the first read or write error.
// An upgrade failure is logged.
func echo(_ context.Context, c *app.RequestContext) {
	log.Println("callback echo")
	err := upgrader.Upgrade(c, func(conn *websocket.Conn) {
		for {
			mt, message, err := conn.ReadMessage()
			if err != nil {
				log.Println("read:", err)
				return
			}
			log.Printf("recv: %s", message)
			if err := conn.WriteMessage(mt, message); err != nil {
				log.Println("write:", err)
				return
			}
		}
	})
	if err != nil {
		// Println rather than Print: Print adds no space between a string
		// operand and the error, which produced "upgrade:<err>" instead of
		// the "upgrade: <err>" form used by the read/write messages.
		log.Println("upgrade:", err)
	}
}
// home handles GET /. It renders homeTemplate as HTML, passing the WebSocket
// URL of the echo endpoint built from the request's Host header.
// A template execution error is logged instead of being discarded.
func home(_ context.Context, c *app.RequestContext) {
	log.Println("callback home")
	c.SetContentType("text/html; charset=utf-8")
	wsURL := "ws://" + string(c.Host()) + "/echo"
	if err := homeTemplate.Execute(c, wsURL); err != nil {
		log.Println("template:", err)
	}
}
// main is the echo server entry point. It parses the command-line flags and
// creates a Hertz server bound to *addr. It then registers the landing page
// ("/") and the WebSocket echo endpoint ("/echo") and hands control to
// srv.Spin().
func main() {
	flag.Parse()

	srv := server.Default(server.WithHostPorts(*addr))
	// NOTE(review): the hertz-contrib/websocket examples set this; presumably
	// so hijacked (upgraded) connections are not recycled through Hertz's
	// pool — confirm against the Hertz documentation.
	srv.NoHijackConnPool = true

	// Registered in order: landing page first, then the echo endpoint.
	routes := []struct {
		path    string
		handler app.HandlerFunc
	}{
		{"/", home},
		{"/echo", echo},
	}
	for _, r := range routes {
		srv.GET(r.path, r.handler)
	}

	log.Println("init finished")
	srv.Spin()
}
// homeTemplate is the landing page rendered by home. Its data value is the
// WebSocket URL of the echo endpoint, which is shown on the page. Because
// this is html/template, the value (derived from the request's Host header)
// is HTML-escaped before output.
//
// The previous version put bare text directly inside <head>, which is not
// valid HTML; that text is now the page <title>.
var homeTemplate = template.Must(template.New("").Parse(`
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>ws head</title>
</head>
<body>
<p>ws body</p>
<p>echo endpoint: {{.}}</p>
</body>
</html>
`))