| 
8 | 8 | 	"encoding/json"  | 
9 | 9 | 	"errors"  | 
10 | 10 | 	"fmt"  | 
 | 11 | +	"io"  | 
11 | 12 | 	"os"  | 
12 | 13 | 	"os/signal"  | 
13 | 14 | 	"regexp"  | 
@@ -1645,6 +1646,69 @@ func (s *server) ReadStateBytes(protoReq *tfplugin6.ReadStateBytes_Request, prot  | 
1645 | 1646 | 	return nil  | 
1646 | 1647 | }  | 
1647 | 1648 | 
 
  | 
 | 1649 | +func (s *server) WriteStateBytes(srv grpc.ClientStreamingServer[tfplugin6.WriteStateBytes_RequestChunk, tfplugin6.WriteStateBytes_Response]) error {  | 
 | 1650 | +	rpc := "WriteStateBytes"  | 
 | 1651 | +	ctx := srv.Context()  | 
 | 1652 | +	ctx = s.loggingContext(ctx)  | 
 | 1653 | +	ctx = logging.RpcContext(ctx, rpc)  | 
 | 1654 | +	// ctx = logging.StateStoreContext(ctx, protoReq.TypeName)  | 
 | 1655 | +	ctx = s.stoppableContext(ctx)  | 
 | 1656 | +	// logging.ProtocolTrace(ctx, "Received request")  | 
 | 1657 | +	// defer logging.ProtocolTrace(ctx, "Served request")  | 
 | 1658 | + | 
 | 1659 | +	ctx = tf6serverlogging.DownstreamRequest(ctx)  | 
 | 1660 | + | 
 | 1661 | +	server, ok := s.downstream.(tfprotov6.StateStoreServer)  | 
 | 1662 | +	if !ok {  | 
 | 1663 | +		err := status.Error(codes.Unimplemented, "ProviderServer does not implement WriteStateBytes")  | 
 | 1664 | +		logging.ProtocolError(ctx, err.Error())  | 
 | 1665 | +		return err  | 
 | 1666 | +	}  | 
 | 1667 | + | 
 | 1668 | +	var iteratorErr error  | 
 | 1669 | + | 
 | 1670 | +	// TODO: what about error handling per chunk and providers having the ability to do cleanup on interruption?  | 
 | 1671 | + | 
 | 1672 | +	iterator := func(yield func(tfprotov6.WriteStateByteChunk) bool) {  | 
 | 1673 | +		for {  | 
 | 1674 | +			chunk, err := srv.Recv()  | 
 | 1675 | +			if err == io.EOF {  | 
 | 1676 | +				break  | 
 | 1677 | +			}  | 
 | 1678 | +			if err != nil {  | 
 | 1679 | +				iteratorErr = err  | 
 | 1680 | +				srv.SendMsg(&tfplugin6.WriteStateBytes_Response{  | 
 | 1681 | +					// Diagnostics: ,  | 
 | 1682 | +				})  | 
 | 1683 | +				return  | 
 | 1684 | +			}  | 
 | 1685 | + | 
 | 1686 | +			yield(tfprotov6.WriteStateByteChunk{  | 
 | 1687 | +				Bytes:       chunk.Bytes,  | 
 | 1688 | +				TotalLength: chunk.TotalLength,  | 
 | 1689 | +				Range: tfprotov6.StateByteRange{  | 
 | 1690 | +					Start: chunk.Range.Start,  | 
 | 1691 | +					End:   chunk.Range.End,  | 
 | 1692 | +				},  | 
 | 1693 | +			})  | 
 | 1694 | + | 
 | 1695 | +		}  | 
 | 1696 | +	}  | 
 | 1697 | + | 
 | 1698 | +	resp, err := server.WriteStateBytes(ctx, &tfprotov6.WriteStateBytesStream{  | 
 | 1699 | +		Chunks: iterator,  | 
 | 1700 | +	})  | 
 | 1701 | +	if err != nil {  | 
 | 1702 | +		return err  | 
 | 1703 | +	}  | 
 | 1704 | + | 
 | 1705 | +	err = srv.SendAndClose(&tfplugin6.WriteStateBytes_Response{  | 
 | 1706 | +		// Diagnostics: resp.Diagnostics,  | 
 | 1707 | +	})  | 
 | 1708 | + | 
 | 1709 | +	return nil  | 
 | 1710 | +}  | 
 | 1711 | + | 
1648 | 1712 | func (s *server) GetStates(ctx context.Context, protoReq *tfplugin6.GetStates_Request) (*tfplugin6.GetStates_Response, error) {  | 
1649 | 1713 | 	rpc := "GetStates"  | 
1650 | 1714 | 	ctx = s.loggingContext(ctx)  | 
 | 
0 commit comments