Skip to content

Commit

Permalink
Update streamPrediction to do non-blocking send to errChan (#84)
Browse files Browse the repository at this point in the history
* Non-blocking send to errChan

* Create sendError helper method
  • Loading branch information
mattt authored Sep 23, 2024
1 parent 3716050 commit 5c48afb
Showing 1 changed file with 15 additions and 23 deletions.
38 changes: 15 additions & 23 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,20 @@ func (e *SSEEvent) String() string {
}
}

func (r *Client) sendError(err error, errChan chan error) {
select {
case errChan <- err:
default:
}
}

func (r *Client) Stream(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (<-chan SSEEvent, <-chan error) {
sseChan := make(chan SSEEvent, 64)
errChan := make(chan error, 64)

id, err := ParseIdentifier(identifier)
if err != nil {
errChan <- err
r.sendError(err, errChan)
return sseChan, errChan
}

Expand All @@ -115,7 +122,7 @@ func (r *Client) Stream(ctx context.Context, identifier string, input Prediction
}

if err != nil {
errChan <- err
r.sendError(err, errChan)
return sseChan, errChan
}

Expand All @@ -136,16 +143,13 @@ func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) (
func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, lastEvent *SSEEvent, sseChan chan SSEEvent, errChan chan error) {
url := prediction.URLs["stream"]
if url == "" {
errChan <- errors.New("streaming not supported or not enabled for this prediction")
r.sendError(errors.New("streaming not supported or not enabled for this prediction"), errChan)
return
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
select {
case errChan <- fmt.Errorf("failed to create request: %w", err):
default:
}
r.sendError(fmt.Errorf("failed to create request: %w", err), errChan)
return
}
req.Header.Set("Accept", "text/event-stream")
Expand All @@ -163,18 +167,12 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
} else {
defer resp.Body.Close()
}
select {
case errChan <- fmt.Errorf("failed to send request: %w", err):
default:
}
r.sendError(fmt.Errorf("failed to send request: %w", err), errChan)
return
}

if resp.StatusCode != http.StatusOK {
select {
case errChan <- fmt.Errorf("received invalid status code: %d", resp.StatusCode):
default:
}
r.sendError(fmt.Errorf("received invalid status code: %d", resp.StatusCode), errChan)
return
}

Expand Down Expand Up @@ -229,10 +227,7 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l

event, err := decodeSSEEvent(b)
if err != nil {
select {
case errChan <- err:
default:
}
r.sendError(err, errChan)
continue
}

Expand Down Expand Up @@ -269,10 +264,7 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
}

if !errors.Is(err, context.Canceled) {
select {
case errChan <- err:
default:
}
r.sendError(err, errChan)
}
}

Expand Down

0 comments on commit 5c48afb

Please sign in to comment.