/* See LICENSE file for copyright and license details. */ package unxml import ( "fmt" "io" "os" "golang.org/x/net/html" ) type Reader struct { reader io.Reader tagmap map[string]bool lastread []byte count int tokenizer *html.Tokenizer } type ElementReader struct { xr Reader tagsinstack map[string]int intagtokeep bool } //type stack []string // //func (s stack) Empty() bool { return len(s) == 0 } //func (s stack) Peek() string { return s[len(s)-1] } //func (s *stack) Pop() string { // d := (*s)[len(*s)-1] // (*s) = (*s)[:len(*s)-1] // return d //} func NewReaderKeepTags(r io.Reader, tagstokeep []string) *Reader { var tagmap map[string]bool if len(tagstokeep) > 0 { tagmap = make(map[string]bool, 10) for _, tag := range tagstokeep { tagmap[tag] = true } } return &Reader{reader: r, tagmap: tagmap, tokenizer: html.NewTokenizer(r), } } func NewReaderKeepElements(r io.Reader, tagstokeep []string) *ElementReader { var tagmap map[string]bool if len(tagstokeep) > 0 { tagmap = make(map[string]bool, 10) for _, tag := range tagstokeep { tagmap[tag] = true } } return &ElementReader{ xr: Reader{ tagmap: tagmap, tokenizer: html.NewTokenizer(r), }, tagsinstack: make(map[string]int, 5), } } func (r *Reader) Read(out []byte) (int, error) { var err error r.count = 0 n := 0 lenout := len(out) if lenout == 0 { return r.count, nil } lenlr := len(r.lastread) if lenlr > 0 { n = copy(out, r.lastread) r.count += n if n < lenlr { r.lastread = r.lastread[n:] return r.count, nil } r.lastread = make([]byte, lenout) } for { tt := r.tokenizer.Next() switch tt { case html.ErrorToken: return r.count, io.EOF case html.TextToken: text := r.tokenizer.Text() n = copy(out[r.count:], text) r.count += n if n < len(text) { r.lastread = text[n:] return r.count, err } case html.StartTagToken: raw := r.tokenizer.Raw() //fmt.Fprintf(os.Stderr, "RawToken: %s\n", raw) tn, _ := r.tokenizer.TagName() //fmt.Fprintf(os.Stderr, "TagName: %s\n", tn) if _, ok := r.tagmap[string(tn)]; ok { n := copy(out[r.count:], raw) r.count += n if n < len(raw) { r.lastread = raw[n:] return r.count, nil } } case html.EndTagToken: raw := r.tokenizer.Raw() tn, _ := r.tokenizer.TagName() if _, ok := r.tagmap[string(tn)]; ok { n := copy(out[r.count:], raw) r.count += n if n < len(raw) { r.lastread = raw[n:] return r.count, nil } } } } } func (r *ElementReader) Read(out []byte) (int, error) { var err error r.xr.count = 0 n := 0 lenout := len(out) if lenout == 0 { return r.xr.count, nil } lenlr := len(r.xr.lastread) if lenlr > 0 { //fmt.Printf("Using lastread: %q\n", r.xr.lastread) n = copy(out, r.xr.lastread) r.xr.count += n if n < lenlr { //fmt.Printf("Using lastread not enough: %q\n", r.xr.lastread) r.xr.lastread = r.xr.lastread[n:] return r.xr.count, err } r.xr.lastread = make([]byte, 0, lenout) } for { tt := r.xr.tokenizer.Next() switch tt { case html.ErrorToken: fmt.Fprintf(os.Stderr, "There was an error when parsing the html: %s, %s\n", tt, r.xr.tokenizer.Err()) return r.xr.count, io.EOF case html.TextToken: if !r.intagtokeep { continue } text := r.xr.tokenizer.Text() //fmt.Printf("HAD SPACE: %q, count: %d\n", text, r.xr.count) n = copy(out[r.xr.count:], text) r.xr.count += n if n < len(text) { //fmt.Printf("HAD NO SPACE: wrote: %q, count: %d, n: %d\n", text[:n], r.xr.count, n) r.xr.lastread = text[n:] //fmt.Printf("lastread is now: %q\n", text[n:]) return r.xr.count, err } case html.StartTagToken: tn, _ := r.xr.tokenizer.TagName() //fmt.Printf("TagNameStart: %s\n", tn) if _, ok := r.xr.tagmap[string(tn)]; ok { r.tagsinstack[string(tn)]++ r.intagtokeep = true raw := r.xr.tokenizer.Raw() //fmt.Printf("TokenRaw: %s\n", raw) n := copy(out[r.xr.count:], raw) r.xr.count += n if n < len(raw) { r.xr.lastread = raw[n:] return r.xr.count, err } } else { r.intagtokeep = false } case html.EndTagToken: tn, _ := r.xr.tokenizer.TagName() //fmt.Printf("TagEndName: %s\n", tn) if count, ok := r.tagsinstack[string(tn)]; ok { //fmt.Printf("TagEndNameInStack: %s, %d\n", tn, count) if count == 1 { delete(r.tagsinstack, string(tn)) r.intagtokeep = false } else { r.tagsinstack[string(tn)]-- } raw := r.xr.tokenizer.Raw() n := copy(out[r.xr.count:], raw) r.xr.count += n if n < len(raw) { r.xr.lastread = raw[n:] return r.xr.count, err } } } } }