Hi folks!
Been struggling with this problem for a while so I figured I’d solicit suggestions here:
I've created a model architecture similar to AlphaFold2, where the input is very heterogeneous and each input type goes through its own series of transformations before being combined into one data "stack" (e.g. a 5x1000 tensor) that gets passed through a shallow ResNet for the classification task.
The largest structural issue I'm facing is that one of the input nodes can have anywhere from 1 channel (shape 1x1000) to 8 channels (shape 8x1000) on any given sample from the dataloader. This is largely fine until I eventually need to encode that variable structure into a single-channel embedding to put on the pre-ResNet data stack.
The things that I’ve looked at so far:
- I could just average them all into one channel (problem: the order of those channels matters quite a bit, and the information lost there feels like it would be immense).
- I could create 8 different subpaths in the model, one per channel count (problem: not enough training data to train most of the subpaths properly; the 1-channel path would be trained far more heavily than the 8-channel path).
- I could run PCA on the transposed tensor with n_components=1 and transpose it back (problem: this just feels dumb, and I'm not sure it's a legitimate idea).
Any other suggestions? Or are there common practices here that I’m just unaware of?
If the order matters, it's pretty common to add positional encodings, like they do for transformers.
Use a transformer layer for aggregation if you want a learnable way of pooling them. Positional encodings ensure the channel order influences the prediction, and masking handles the variable channel count within a batch.
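For concreteness, here's a minimal sketch of that kind of learnable attention pooling in PyTorch. The module and parameter names are mine, and the 1000-dim feature size and 8-channel maximum are taken from your example, so treat it as a starting point rather than a drop-in:

```python
import torch
import torch.nn as nn

class ChannelAttentionPool(nn.Module):
    """Pool a variable number of channels (each a 1000-dim row) into a
    single channel with a learnable attention query. Names and sizes are
    illustrative, not from the original model."""

    def __init__(self, feat_dim=1000, max_channels=8, n_heads=4):
        super().__init__()
        # Learnable positional embeddings so channel order matters.
        self.pos_emb = nn.Parameter(torch.randn(max_channels, feat_dim) * 0.02)
        # One learnable query token that attends over the channel "tokens".
        self.query = nn.Parameter(torch.randn(1, 1, feat_dim) * 0.02)
        self.attn = nn.MultiheadAttention(feat_dim, n_heads, batch_first=True)

    def forward(self, x, n_channels):
        # x: (batch, max_channels, feat_dim), zero-padded along the channel dim
        # n_channels: (batch,) true channel count per sample
        b, c, _ = x.shape
        x = x + self.pos_emb[:c]                      # inject channel order
        # True where a slot is padding, so attention ignores it.
        pad_mask = torch.arange(c, device=x.device)[None, :] >= n_channels[:, None]
        q = self.query.expand(b, -1, -1)              # (batch, 1, feat_dim)
        pooled, _ = self.attn(q, x, x, key_padding_mask=pad_mask)
        return pooled.squeeze(1)                      # (batch, feat_dim)
```

The learnable query acts like a [CLS] token: whether a sample shows up with 1 channel or 8, the output is always a single 1x1000 embedding that can go straight onto the pre-ResNet stack.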
Thanks - that's where I had started leaning, but wanted to be sure. And just to be clear, I'd functionally need to feed the data through the transformer in a tokenized manner, since the shape of the input tensor is variable? So basically split the input into chunks, with their indices as the queries in the attention layer, and in the forward pass just loop through the input until I'm done. u/Green_ninjas, u/pm_me_your_pay_slips
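For what it's worth, a per-sample loop shouldn't be needed: if each channel is treated as one token and the batch is padded to the longest channel count, the whole batch goes through attention in a single forward pass, with the mask hiding the padded slots. A rough sketch of the dataloader side, using the module above (the collate_fn name and shapes are my assumptions):

```python
import torch

def collate_variable_channels(samples, max_channels=8):
    """Hypothetical DataLoader collate_fn: pad (c_i, 1000) tensors to a
    fixed (batch, max_channels, 1000) block and record true lengths."""
    feat_dim = samples[0].shape[1]
    batch = torch.zeros(len(samples), max_channels, feat_dim)
    n_channels = torch.tensor([s.shape[0] for s in samples])
    for i, s in enumerate(samples):
        batch[i, : s.shape[0]] = s   # real channels first, zeros as padding
    return batch, n_channels

# Each channel is one token; the mask hides the padded slots.
samples = [torch.randn(c, 1000) for c in (1, 3, 8)]
x, n = collate_variable_channels(samples)
pool = ChannelAttentionPool()   # from the sketch above
out = pool(x, n)                # shape (3, 1000): one channel per sample
```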