```python
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        ...
```
Hi, I am having some trouble setting the dimensions for the modules in my code.
Could you please tell me how I should set `query_dim` and `context_dim` if I want to use this module to compute cross-attention between a feature map X and a text?
Is `query_dim` the channel dimension of the feature map?
Is `context_dim` the dimension of the tensor I get after feeding the text into a text encoder?
Thank you so much!
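
For reference, here is a sketch of what I am currently trying, so you can see my understanding. The concrete shapes (a 320-channel feature map, a CLIP-style text embedding of shape `(batch, 77, 768)`) and the `attn(x, context=...)` call are just my assumptions; please correct me if the forward signature or the shapes are wrong:

```python
import torch

B, C, H, W = 2, 320, 32, 32
x = torch.randn(B, C, H, W)        # feature map X
context = torch.randn(B, 77, 768)  # text encoder output (assumed shape)

# my guess: query_dim = channels of X, context_dim = text embedding dim
attn = CrossAttention(query_dim=C, context_dim=768, heads=8, dim_head=64)

# flatten the spatial dims so each pixel becomes a query token: (B, H*W, C)
x_seq = x.reshape(B, C, H * W).permute(0, 2, 1)

out = attn(x_seq, context=context)  # expecting output of shape (B, H*W, C)
```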