|
8 | 8 | "encoding/json" |
9 | 9 | "errors" |
10 | 10 | "fmt" |
| 11 | + "io" |
11 | 12 | "os" |
12 | 13 | "os/signal" |
13 | 14 | "regexp" |
@@ -1645,6 +1646,78 @@ func (s *server) ReadStateBytes(protoReq *tfplugin6.ReadStateBytes_Request, prot |
1645 | 1646 | return nil |
1646 | 1647 | } |
1647 | 1648 |
|
| 1649 | +func (s *server) WriteStateBytes(srv grpc.ClientStreamingServer[tfplugin6.WriteStateBytes_RequestChunk, tfplugin6.WriteStateBytes_Response]) error { |
| 1650 | + rpc := "WriteStateBytes" |
| 1651 | + ctx := srv.Context() |
| 1652 | + ctx = s.loggingContext(ctx) |
| 1653 | + ctx = logging.RpcContext(ctx, rpc) |
| 1654 | + // ctx = logging.StateStoreContext(ctx, protoReq.TypeName) |
| 1655 | + ctx = s.stoppableContext(ctx) |
| 1656 | + // logging.ProtocolTrace(ctx, "Received request") |
| 1657 | + // defer logging.ProtocolTrace(ctx, "Served request") |
| 1658 | + |
| 1659 | + ctx = tf6serverlogging.DownstreamRequest(ctx) |
| 1660 | + |
| 1661 | + server, ok := s.downstream.(tfprotov6.StateStoreServer) |
| 1662 | + if !ok { |
| 1663 | + err := status.Error(codes.Unimplemented, "ProviderServer does not implement WriteStateBytes") |
| 1664 | + logging.ProtocolError(ctx, err.Error()) |
| 1665 | + return err |
| 1666 | + } |
| 1667 | + |
| 1668 | + iterator := func(yield func(tfprotov6.WriteStateByteChunk) bool) { |
| 1669 | + for { |
| 1670 | + chunk, err := srv.Recv() |
| 1671 | + if err == io.EOF { |
| 1672 | + break |
| 1673 | + } |
| 1674 | + if err != nil { |
| 1675 | + // attempt to send the error back to client |
| 1676 | + msgErr := srv.SendMsg(&tfplugin6.WriteStateBytes_Response{ |
| 1677 | + Diagnostics: toproto.Diagnostics([]*tfprotov6.Diagnostic{ |
| 1678 | + { |
| 1679 | + Severity: tfprotov6.DiagnosticSeverityError, |
| 1680 | + Summary: "Writing state chunk failed", |
| 1681 | + Detail: fmt.Sprintf("Attempt to write a byte chunk of state %q to %q failed: %s", |
| 1682 | + chunk.StateId, chunk.TypeName, err), |
| 1683 | + }, |
| 1684 | + }), |
| 1685 | + }) |
| 1686 | + if msgErr != nil { |
| 1687 | + err := status.Error(codes.Unimplemented, "ProviderServer does not implement WriteStateBytes") |
| 1688 | + logging.ProtocolError(ctx, err.Error()) |
| 1689 | + return |
| 1690 | + } |
| 1691 | + return |
| 1692 | + } |
| 1693 | + |
| 1694 | + ok := yield(tfprotov6.WriteStateByteChunk{ |
| 1695 | + Bytes: chunk.Bytes, |
| 1696 | + TotalLength: chunk.TotalLength, |
| 1697 | + Range: tfprotov6.StateByteRange{ |
| 1698 | + Start: chunk.Range.Start, |
| 1699 | + End: chunk.Range.End, |
| 1700 | + }, |
| 1701 | + }) |
| 1702 | + if !ok { |
| 1703 | + return |
| 1704 | + } |
| 1705 | + |
| 1706 | + } |
| 1707 | + } |
| 1708 | + |
| 1709 | + resp, err := server.WriteStateBytes(ctx, &tfprotov6.WriteStateBytesStream{ |
| 1710 | + Chunks: iterator, |
| 1711 | + }) |
| 1712 | + if err != nil { |
| 1713 | + return err |
| 1714 | + } |
| 1715 | + |
| 1716 | + return srv.SendAndClose(&tfplugin6.WriteStateBytes_Response{ |
| 1717 | + Diagnostics: toproto.Diagnostics(resp.Diagnostics), |
| 1718 | + }) |
| 1719 | +} |
| 1720 | + |
1648 | 1721 | func (s *server) GetStates(ctx context.Context, protoReq *tfplugin6.GetStates_Request) (*tfplugin6.GetStates_Response, error) { |
1649 | 1722 | rpc := "GetStates" |
1650 | 1723 | ctx = s.loggingContext(ctx) |
|
0 commit comments