Dataset-Induced Meta-Learning (and other tricks): Improving model efficiency on ARC
In this post, we give a fuller explanation of our approach and share recent lessons that may be useful to others tackling this difficult challenge. We have taken an empirical, machine learning-based approach to the ARC Challenge, fine-tuning large language models on a blend of synthetic ARC riddles, augmented ARC data, and NLP tasks. Our work has been supported by computational resources from Google’s TPU Research Cloud. Our best model has improved steadily: it currently solves 13 items on the private test set and seems to gain one correct item every week or two.
📅 October 9, 2023
We have optimized our training system to draw on a diverse range of datasets, with a focus on meta-learning and broad capabilities. Pre-fine-tuning on either NLP or ARC datasets has proven to accelerate learning compared with training on ARC data alone, or on combined ARC and NLP data, from the start. We also experimented with evaluation-time strategies, including ensembling, test-time fine-tuning (TTFT), and a Correct Answer Collection System (CACS). TTFT is a key innovation: it takes our current best model from 4-5 items correct without TTFT to 13 correct with it.
Although we started out training models on our own PCs, it soon became clear that this would be insufficient. We applied for and received credits to use TPUs through the Google TPU Research Cloud (TRC). The number and size of the TPUs available to us are not enough to train anything over 1B parameters, but we have been able to run numerous experiments and train smaller models. If you apply for TPU access through the TRC, we recommend sending a follow-up email to request more time. We requested access to v2 and v3 TPUs. We are very thankful to Google and the TRC for allowing us to use this resource.
Our philosophy is that if ARC is a measure of intelligence, then the model should be trained to be as generally intelligent as possible. This means it must be able to solve a wide range of tasks beyond ARC. It must also be able to draw on what it has learned in other contexts when solving ARC riddles. To choose the NLP datasets with the highest impact on ARC performance, we have closely followed recent papers on model releases and datasets. We discuss this further in the Dataset-Induced Meta-Learning section below.
It has been a long road to the trainer we currently use. In essence, it loads and samples from the many generated datasets we use, and it mixes in the NLP tasks we have added from Hugging Face. We train Flax-based models with JAX.

The trainer defines an epoch as a fixed number of training samples drawn at random from the full set of datasets. It tracks which items have been used, so that no item is sampled again until every dataset has been exhausted or training restarts. It also detects newly added datasets and incorporates their data without stopping training.

We also train the models for instruction following and chat-based interactions. To that end, we add many variations of system, user, and response prefixes so that the model can handle different prompt formats. Our aim is to make small models that are as versatile and useful as the state of the art allows. We hope that training on algorithmic learning and ARC will improve them across a range of tasks.
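The sampling behaviour described above can be sketched as follows. This is a minimal illustration in plain Python, not the authors' JAX trainer. The following are all assumptions:

- the class name `NoReplacementSampler`;
- treating items as strings;
- sampling uniformly over all unseen items, so larger datasets contribute proportionally more.

```python
import random
from typing import Dict, List, Tuple


class NoReplacementSampler:
    """Hypothetical sampler: fixed-size 'epochs' drawn from a pool of named datasets.

    Items already drawn are not drawn again until every item in every dataset
    has been used (or reset() is called). Datasets can be added mid-training.
    """

    def __init__(self, epoch_size: int, seed: int = 0) -> None:
        self.epoch_size = epoch_size
        self.rng = random.Random(seed)
        self.datasets: Dict[str, List[str]] = {}
        self.remaining: List[Tuple[str, int]] = []  # unseen (dataset name, item index)

    def add_dataset(self, name: str, items: List[str]) -> None:
        """Register a new dataset; its items become eligible immediately."""
        if name in self.datasets:
            raise ValueError(f"dataset {name!r} already registered")
        stored = list(items)
        self.datasets[name] = stored
        self.remaining.extend((name, i) for i in range(len(stored)))

    def reset(self) -> None:
        """Start a new pass: every item of every dataset is eligible again."""
        self.remaining = [(n, i) for n, items in self.datasets.items() for i in range(len(items))]

    def next_epoch(self) -> List[str]:
        batch = []
        for _ in range(self.epoch_size):
            if not self.remaining:
                self.reset()  # all items seen once: begin a new pass
                if not self.remaining:
                    raise RuntimeError("no data registered")
            # Draw a random unseen item and remove it with an O(1) swap-pop.
            j = self.rng.randrange(len(self.remaining))
            self.remaining[j], self.remaining[-1] = self.remaining[-1], self.remaining[j]
            name, idx = self.remaining.pop()
            batch.append(self.datasets[name][idx])
        return batch
```

Calling `add_dataset` between epochs is one simple way to pick up new data without stopping training. How new dataset files are actually detected is not shown here.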
GPT-4 has been useful for writing code that generates synthetic ARC riddles. Although GPT-4 can solve a limited number of ARC riddles directly, it performs much better at writing code to generate synthetic riddles. You can learn more about our method by reading a chat with GPT-4 Advanced Data Analysis in which we create riddles based on the concept of gravity. GPT-4 has also helped us solve numerous programming problems and has suggested new ideas and variations for training data.
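To make the idea of a riddle generator concrete, here is a minimal sketch of a gravity-style generator. It is not the code from the GPT-4 chat; the function names and the board sizes and densities are hypothetical.

- Boards are lists of lists of colour indices 0-9, with 0 as background.
- Riddles follow ARC's JSON layout of `train` and `test` lists of `input`/`output` pairs.

```python
import random
from typing import Dict, List

Grid = List[List[int]]


def apply_gravity(grid: Grid, background: int = 0) -> Grid:
    """Let every non-background cell fall to the bottom of its column, keeping order."""
    h, w = len(grid), len(grid[0])
    out = [[background] * w for _ in range(h)]
    for c in range(w):
        column = [grid[r][c] for r in range(h) if grid[r][c] != background]
        # Stack the falling cells at the bottom, preserving top-to-bottom order.
        for k, value in enumerate(column):
            out[h - len(column) + k][c] = value
    return out


def random_board(rng: random.Random, h: int, w: int, density: float = 0.3) -> Grid:
    """Hypothetical helper: scatter random colours 1-9 over a blank board."""
    return [[rng.randint(1, 9) if rng.random() < density else 0 for _ in range(w)] for _ in range(h)]


def make_gravity_riddle(n_train: int = 3, seed: int = 0) -> Dict:
    rng = random.Random(seed)
    pairs = []
    for _ in range(n_train + 1):
        board = random_board(rng, rng.randint(3, 10), rng.randint(3, 10))
        pairs.append({"input": board, "output": apply_gravity(board)})
    return {"train": pairs[:n_train], "test": pairs[n_train:]}
```

Variations such as gravity toward other edges, or gravity that only moves one colour, are easy to add once the core transformation is written as a function.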
Approaches to Meta-Learning
Dataset-Induced Meta-Learning
Not all model pre-training is created equal when it comes to performance on ARC. This applies both to how well a model can ultimately perform on ARC and to how quickly it can learn ARC. We found that fine-tuning a model on our selection of NLP datasets before adding ARC data dramatically improves the evaluation loss curve and performance on ARC. We coined the phrase dataset-induced meta-learning for the phenomenon of models learning new tasks more quickly because of their pre-training or fine-tuning data.
Figure 1: NLP Pre-Fine-Tuning Results
Note: The loss shown above is the prediction loss during evaluation on the public test set. Each run trained on its initial data for 1.6M examples and then switched to all datasets for another 1.6M examples. In the All Datasets and ARC Only conditions, the data stayed the same throughout, which resulted in a very flat loss curve.

The conditions use the following datasets:

- PCFG stands for Probabilistic Context-Free Grammar. We generated a PCFG dataset of 100 different algorithms and formatted the items in several ways, including as few-shot examples similar to ARC. Mirchandani et al. (2023) provide a strong rationale for using PCFG tasks for ARC.
- Simple MM refers to simple multimodal tasks. We developed a simple multimodal dataset that converts ARC-like boards to and from base64-encoded PNG images.
- The complex multimodal tasks are our own tasks that involve translating between natural language, ARC-style boards, Python code, and base64-encoded PNG images.
- We also include several multimodal datasets from the Multimodal Instruction Dataset (primarily the reasoning datasets).
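A PCFG-style few-shot item of the kind described in the note might be built as follows. This is a minimal sketch.

- The handful of sequence operations is loosely modelled on PCFG-style transformation benchmarks and is purely illustrative. It is not the authors' set of 100 algorithms.
- `OPS`, `make_few_shot_item`, and the `input:`/`output:` prompt layout are hypothetical.

```python
import random
from typing import Callable, Dict, List

# Illustrative unary operations on token sequences (not the authors' algorithms).
OPS: Dict[str, Callable[[List[str]], List[str]]] = {
    "copy": lambda s: list(s),
    "reverse": lambda s: s[::-1],
    "shift": lambda s: s[1:] + s[:1],  # rotate left by one token
    "echo": lambda s: s + s[-1:],  # repeat the last token
    "swap_first_last": lambda s: (s[-1:] + s[1:-1] + s[:1]) if len(s) > 1 else list(s),
    "repeat": lambda s: s + s,
}


def run_program(program: List[str], seq: List[str]) -> List[str]:
    """Apply the named operations left to right."""
    for name in program:
        seq = OPS[name](seq)
    return seq


def make_few_shot_item(rng: random.Random, n_shots: int = 3, depth: int = 2, vocab: str = "ABCDEFGH") -> Dict:
    """Sample a random composed program and format n_shots demonstrations plus a query."""
    program = [rng.choice(sorted(OPS)) for _ in range(depth)]
    examples = []
    for _ in range(n_shots + 1):
        seq = [rng.choice(vocab) for _ in range(rng.randint(3, 6))]
        examples.append((seq, run_program(program, seq)))
    shots = [f"input: {' '.join(i)}\noutput: {' '.join(o)}" for i, o in examples[:-1]]
    query_in, query_out = examples[-1]
    prompt = "\n".join(shots + [f"input: {' '.join(query_in)}\noutput:"])
    return {"prompt": prompt, "target": " ".join(query_out), "program": program}
```

Because the program is hidden and must be inferred from the demonstrations, each item has the same few-shot structure as an ARC riddle, only over token sequences instead of boards.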
As Figure 1 shows, training with only ARC data, or with ARC and NLP data combined from the beginning, results in a very flat loss curve for both Flan-T5 (Chung et al., 2022) and LongT5 (Guo et al., 2022). Salesforce's CodeT5 (Wang et al., 2021) performs better, which further supports our point about the importance of pre-training and pre-fine-tuning. Even though CodeT5 learns ARC more quickly than the baseline Flan-T5 and LongT5 models, it also benefits significantly from our NLP pre-fine-tuning.

Model performance does eventually improve when training on ARC only, or on ARC and NLP data from the start. However, you can dramatically cut training time in either of two ways: pre-fine-tune on NLP data and then add ARC, or pre-fine-tune on ARC and then add NLP data. It works both ways, but without introducing one after pre-fine-tuning on the other, the loss curve seems to be a very difficult hill to climb. It is this acceleration of learning that we call dataset-induced meta-learning, and we think it will be an important concept to explore in future research.
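The two-phase schedule behind these results can be written as a sampling-weight function. This is a sketch under assumptions.

- The 1.6M-example switch point comes from the Figure 1 note.
- The group names are hypothetical.
- Weighting active groups by size, which amounts to sampling items uniformly from the pooled data, is our assumption.

```python
from typing import Dict, List

SWITCH_AT = 1_600_000  # examples seen before switching to all datasets (Figure 1 note)


def mixture_weights(examples_seen: int, sizes: Dict[str, int], first_phase: List[str],
                    switch_at: int = SWITCH_AT) -> Dict[str, float]:
    """Sampling probability per dataset group under a two-phase curriculum.

    Before `switch_at` examples only the `first_phase` groups are active;
    afterwards every group is. Active groups are weighted by their size.
    """
    active = first_phase if examples_seen < switch_at else list(sizes)
    total = sum(sizes[g] for g in active)
    return {g: (sizes[g] / total if g in active else 0.0) for g in sizes}
```

Passing `first_phase=["nlp"]` gives NLP pre-fine-tuning followed by all data. Passing `first_phase=["arc_aug"]` gives the reverse ordering, which the results above suggest also works.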
Meta-Learning through Large Concept Datasets
One technique we use to induce more priors into the model, as well as meta-learning, is what we call the cross-item riddle. By cross-item, we mean that solving a single riddle requires concepts learned across a large number of other riddles. One example is our cellular automata (CA) riddles (see Figure 2 below for an example). This dataset uses several types of cellular automata and many different rules.

The boards are built in two steps:

- Each input board is produced by running an automaton for a random number of steps.
- Each output board is produced by advancing its input board by a fixed number of steps, which is the same for every pair within a riddle.

To solve these riddles, the model must learn the underlying concepts of cellular automata. In our experiments, this approach seemed to produce models that can learn, and be fine-tuned, from less training data. That in turn makes it practical to use small datasets built from the evaluation riddles for test-time fine-tuning.
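A minimal sketch of such a generator follows. It assumes Life-like (outer-totalistic) rules with a Moore neighbourhood, a non-wrapping boundary where cells beyond the edge count as dead, and binary boards. The post says only that several automaton types and many rules are used, so these choices, the function names, and all size and range parameters are hypothetical.

```python
import random
from typing import Dict, FrozenSet, List

Grid = List[List[int]]


def life_like_step(grid: Grid, birth: FrozenSet[int], survive: FrozenSet[int]) -> Grid:
    """One step of a Life-like CA (Moore neighbourhood, cells beyond the edge are dead)."""
    h, w = len(grid), len(grid[0])
    nxt = [[0] * w for _ in range(h)]
    for r in range(h):
        for c in range(w):
            n = sum(
                grid[rr][cc] != 0
                for rr in range(max(0, r - 1), min(h, r + 2))
                for cc in range(max(0, c - 1), min(w, c + 2))
                if (rr, cc) != (r, c)
            )
            alive = grid[r][c] != 0
            nxt[r][c] = 1 if (n in survive if alive else n in birth) else 0
    return nxt


def advance(grid: Grid, steps: int, birth: FrozenSet[int], survive: FrozenSet[int]) -> Grid:
    for _ in range(steps):
        grid = life_like_step(grid, birth, survive)
    return grid


def make_ca_riddle(seed: int = 0, n_pairs: int = 4, size: int = 8) -> Dict:
    rng = random.Random(seed)
    # One rule and one step offset per riddle; the model must infer both from the pairs.
    birth = frozenset(rng.sample(range(1, 9), rng.randint(1, 3)))
    survive = frozenset(rng.sample(range(0, 9), rng.randint(1, 4)))
    k = rng.randint(1, 3)
    pairs = []
    for _ in range(n_pairs):
        start = [[int(rng.random() < 0.35) for _ in range(size)] for _ in range(size)]
        inp = advance(start, rng.randint(0, 5), birth, survive)  # random warm-up steps
        pairs.append({"input": inp, "output": advance(inp, k, birth, survive)})
    return {"train": pairs[:-1], "test": pairs[-1:]}
```

The rule and the step count vary between riddles. A single riddle therefore gives too little evidence on its own, and the model has to learn how such automata behave across many riddles, which is the cross-item idea.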
We have not yet explored the cross-item approach rigorously, but we do see a clear benefit from these large-scale concepts. Another large-scale concept we use is based on boards generated from a wide variety of mathematical equations (see Figure 4). As with the CA riddles, output boards are produced by advancing a fixed number of steps from the input boards. For both large-scale concepts, we also have a variety of repair riddles, in which the model must fill in the missing parts of a board. Because these generators produce a wide variety of shapes and patterns, repair riddles induce many priors into the model, including symmetry, shape, progression, and counting.
Figure 4: Examples of Area Repair Riddles with Boards Generated by Math Equations
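An equation-based repair riddle can be sketched as shown below. The specific polynomial, the choice of colours 1-9 for the board, and the use of 0 to mark the missing region are all assumptions, as are the function names. The post does not give the actual equations.

```python
import random
from typing import Dict, List, Tuple

Grid = List[List[int]]


def equation_board(h: int, w: int, coeffs: Tuple[int, int, int, int], t: int = 0) -> Grid:
    """Colour each cell by a polynomial in (x, y, t), mapped into colours 1..9."""
    a, b, c, d = coeffs
    return [[(a * x * x + b * x * y + c * y + d * t) % 9 + 1 for x in range(w)] for y in range(h)]


def cut_hole(board: Grid, rng: random.Random, max_frac: float = 0.4) -> Tuple[Grid, Tuple[int, int, int, int]]:
    """Blank out a random rectangle with 0; returns the damaged copy and (top, left, height, width)."""
    h, w = len(board), len(board[0])
    hh = rng.randint(1, max(1, int(h * max_frac)))
    ww = rng.randint(1, max(1, int(w * max_frac)))
    top, left = rng.randint(0, h - hh), rng.randint(0, w - ww)
    damaged = [row[:] for row in board]
    for r in range(top, top + hh):
        for c in range(left, left + ww):
            damaged[r][c] = 0
    return damaged, (top, left, hh, ww)


def make_repair_riddle(seed: int = 0, n_pairs: int = 4) -> Dict:
    rng = random.Random(seed)
    coeffs = tuple(rng.randint(0, 8) for _ in range(4))  # shared by every pair in the riddle
    pairs = []
    for _ in range(n_pairs):
        board = equation_board(rng.randint(6, 12), rng.randint(6, 12), coeffs, t=rng.randint(0, 20))
        damaged, _ = cut_hole(board, rng)
        pairs.append({"input": damaged, "output": board})
    return {"train": pairs[:-1], "test": pairs[-1:]}
```

Swapping the damaged input for `equation_board(..., t=t + k)` as the output would give the fixed-step advance variant described above.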
Maintaining In-Distribution Training
With so many out-of-distribution datasets, one might expect the resulting model to drift away from the ARC distribution. We address this by monitoring the training loss and model performance whenever we add new data. When a model starts drifting out of distribution, we add more data that is in-distribution for ARC, namely augmented items from the ARC training and evaluation datasets. This pulls the model back towards ARC, reducing evaluation loss and improving performance on ARC.
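The post describes this correction as a judgement made while watching the curves. The rule below is only a hypothetical way to automate it. The `arc_augmented` group name, the 2% relative tolerance, and the 1.5x boost are all assumptions.

```python
from typing import Dict


def rebalance(weights: Dict[str, float], arc_eval_loss: float, best_arc_eval_loss: float,
              key: str = "arc_augmented", tolerance: float = 0.02, boost: float = 1.5) -> Dict[str, float]:
    """Up-weight in-distribution ARC data when ARC eval loss drifts upward.

    If the current ARC eval loss exceeds its best value by more than
    `tolerance` (relative), multiply the weight of group `key` by `boost`.
    Weights are always renormalised to sum to 1.
    """
    new = dict(weights)
    if arc_eval_loss > best_arc_eval_loss * (1 + tolerance):
        new[key] = new[key] * boost
    total = sum(new.values())
    return {k: v / total for k, v in new.items()}
```

A fixed multiplicative boost keeps the rule simple. In practice, one would likely also cap the ARC share, so that the broader datasets are not squeezed out entirely.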
Discussion on Multi-task Learning & Its Implications
We have explored, and continue to explore, various approaches to ensembling. Earlier, we tried k-fold ensembling (Kohavi, 1995). We trained 6 models, each leaving out a different 1/6 of the training data, and combined their results by voting. This did not increase the total number of items correct, but it made the results more reliable. For example, ensembled results varied much less from epoch to epoch as training continued. Interesting as it is, this approach is ultimately impractical. We now have a system ready that ensembles two or three different models, each providing one or more answers for the final submission.
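The k-fold setup and the voting step can be sketched as follows. Shuffling with a fixed seed and ranking distinct grids by vote count are our assumptions; the post says only that results were combined by voting. The function names are hypothetical.

```python
import random
from collections import Counter
from typing import List, Sequence

Grid = List[List[int]]


def k_fold_train_splits(items: Sequence, k: int = 6, seed: int = 0) -> List[List]:
    """Return k training subsets; subset i leaves out the i-th fold (about 1/k of the data)."""
    order = list(items)
    random.Random(seed).shuffle(order)
    folds = [order[i::k] for i in range(k)]
    return [[x for j, fold in enumerate(folds) if j != i for x in fold] for i in range(k)]


def majority_vote(predictions: List[Grid], top_n: int = 3) -> List[Grid]:
    """Rank the distinct predicted grids by how many models produced them."""
    counts = Counter(tuple(map(tuple, g)) for g in predictions)  # grids must be hashable to count
    return [[list(row) for row in g] for g, _ in counts.most_common(top_n)]
```

Each of the k models sees a different held-out fold, and the fold-level disagreement is what voting smooths out.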
Ensembling makes sense because we have saved many model checkpoints, and different models differ significantly in which items they can solve. We can perform test-time fine-tuning on multiple models and take a specified number of answers from each for the final submission. The same system also works for combining other open-source ARC solutions into a super-ensemble. Taking an existing ensemble on Kaggle (last edited by Michael Hodel) as inspiration, we wrote a script that runs some or all existing solutions in parallel, which greatly reduces runtime. This matters given the GPU runtime limits of the official challenge.

Without our model, this ensemble got 29 of the 100 items in the private test set correct with a runtime of 2 hours. It also left the second and third answers open for many items, which our model could fill. With the full ensemble, including our model, we got 52% correct on the 100-item public test set (104 items when counting those with more than one test pair). We have not yet submitted it to Lab42’s ARCathon to run it against the private test set.
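Filling the answer slots for one test board from several sources can be sketched as below. The three-slot limit follows the "second and third answer" mentioned above. The priority-ordered `(source, quota)` list, the duplicate skipping, and all names are assumptions about how such a system might work.

```python
from typing import Dict, List, Tuple

Grid = List[List[int]]


def merge_attempts(sources: List[Tuple[str, int]], answers: Dict[str, List[Grid]],
                   max_attempts: int = 3) -> List[Grid]:
    """Fill up to `max_attempts` answer slots for one test board.

    `sources` is a priority-ordered list of (source name, quota). Each source
    contributes at most `quota` of its ranked answers; duplicates are skipped.
    """
    chosen: List[Grid] = []
    for name, quota in sources:
        taken = 0
        for grid in answers.get(name, []):
            if len(chosen) == max_attempts:
                return chosen
            if taken == quota:
                break
            if grid not in chosen:
                chosen.append(grid)
                taken += 1
    return chosen
```

For example, `[("ensemble", 1), ("our_model", 3)]` lets the existing solutions claim the first slot, and our model fills whatever slots remain.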
Test-Time Fine Tuning (TTFT)
Test-time fine-tuning is another technique that has improved the performance of our models. At test time, we take the given “training pairs” in each test-set riddle and augment them to create synthetic training riddles. We then use these synthetic riddles for a short fine-tuning run before making predictions on the test set.
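The post does not list the augmentations used. The sketch below uses two common ones for ARC: the eight symmetries of the square (rotations and reflections), and a random permutation of colours. Keeping colour 0 fixed as background is an assumption, and so are the function names.

```python
import random
from typing import Dict, List

Grid = List[List[int]]


def rot90(g: Grid) -> Grid:
    """Rotate a grid 90 degrees clockwise."""
    return [list(row) for row in zip(*g[::-1])]


def transpose(g: Grid) -> Grid:
    return [list(row) for row in zip(*g)]


def dihedral(g: Grid, k: int) -> Grid:
    """k in 0..7: k % 4 clockwise rotations, preceded by a transpose when k >= 4."""
    if k >= 4:
        g = transpose(g)
    for _ in range(k % 4):
        g = rot90(g)
    return g


def augment_riddle(riddle: Dict, rng: random.Random) -> Dict:
    """Apply one random symmetry and one colour permutation to every board in the riddle."""
    k = rng.randrange(8)
    colours = list(range(1, 10))
    rng.shuffle(colours)
    cmap = [0] + colours  # colour 0 (background) stays fixed: an assumption

    def f(g: Grid) -> Grid:
        return [[cmap[v] for v in row] for row in dihedral(g, k)]

    return {split: [{key: f(board) for key, board in pair.items()} for pair in riddle[split]]
            for split in ("train", "test")}
```

Because the same symmetry and colour map are applied to every board, the transformation the riddle encodes stays consistent across its pairs. Each call therefore yields a new, still-solvable variant.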
The idea is that the model will make mistakes on the synthetic riddles built from the private test set, and fine-tuning turns those mistakes into useful feedback. The model can then revise its candidate solutions (possible transformations) in light of that feedback, which improves its chances of success.

TTFT may also help the model produce exactly the right answer. Sometimes the model finds the correct transformation but is off by a pixel or two. This slight inaccuracy is probably due to the model's small size and possibly imperfect training. TTFT may correct it by specializing the model a little more to the given riddles and transformations, so that it can execute them perfectly.
The whole paradigm of test-time fine-tuning is, of course, a meta-learning idea, which brings us back to the view that meta-learning is the essential part of ARC. As humans, we test candidate solutions in our minds against the train pairs, iterating before we attempt the test board, and that process is essential to how we solve the riddles. Similarly, in the way we construct the TTFT dataset, the model is trained to get each train pair right when it is presented as the test pair. We do not yet fully understand why TTFT is so beneficial, but we plan to keep exploring this important question.
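Presenting each train pair as the test pair amounts to a leave-one-out construction, sketched below. The optional `transform` hook (for example, a random rotation or recolouring) and the `copies` parameter are hypothetical additions for producing several augmented variants.

```python
from typing import Callable, Dict, List, Optional


def leave_one_out_riddles(riddle: Dict, transform: Optional[Callable[[Dict], Dict]] = None,
                          copies: int = 1) -> List[Dict]:
    """Turn each train pair of a test riddle into the test pair of a new synthetic riddle.

    The remaining train pairs act as demonstrations. If `transform` is given,
    it is applied to every synthetic riddle (e.g. a random symmetry).
    """
    train = riddle["train"]
    if len(train) < 2:
        return []  # need at least one demonstration left after holding one out
    out = []
    for _ in range(copies):
        for i, held_out in enumerate(train):
            synthetic = {"train": train[:i] + train[i + 1:], "test": [held_out]}
            out.append(transform(synthetic) if transform else synthetic)
    return out
```

The original test input is never used as a training target, since its answer is unknown. Every synthetic riddle has a known answer, so it can be used for fine-tuning.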
Correct Answer Collection System (CACS)
Finally, we have developed what we call a Correct Answer Collection System (CACS). CACS builds a new test dataset from the existing test items. Because the answers to the real test pairs are not available during the formal challenge, it drops those test pairs and turns one of the train pairs into the test pair. This lets us check whether a solution can solve the item. The method is imperfect, because some items require all of their train pairs to be present, but such items are infrequent enough that the method is still useful.
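The probing step can be sketched as follows. These details are assumptions:

- The last train pair is used as the probe; the post says only "one of the train pairs".
- A solver is a callable that takes a riddle and returns ranked candidate grids for its first test input.

The function names are hypothetical.

```python
from typing import Callable, Dict, List

Grid = List[List[int]]
Solver = Callable[[Dict], List[Grid]]  # riddle -> ranked candidates for its first test input


def cacs_probe(riddle: Dict) -> Dict:
    """Drop the real test pair(s) and promote the last train pair to be the test."""
    train = riddle["train"]
    return {"train": train[:-1], "test": [train[-1]]}


def collect_verified(riddle: Dict, solvers: Dict[str, Solver], max_attempts: int = 3) -> List[str]:
    """Names of solvers whose candidates on the probe riddle include the held-out output."""
    probe = cacs_probe(riddle)
    expected = probe["test"][0]["output"]
    # Hide the answer so solvers see only the probe input.
    hidden = {"train": probe["train"], "test": [{"input": probe["test"][0]["input"]}]}
    return [name for name, solve in solvers.items() if expected in solve(hidden)[:max_attempts]]
```

Solvers that pass the probe are good candidates for supplying answers to the real test board. Items that need every train pair will be missed, as noted above.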
- Computational Resources: TPUs are helpful for efficient experimentation with smaller LLMs.
- Pre-Fine-Tuning and Model Selection: Strategic selection of NLP datasets can improve training efficiency and model performance. Selecting models with code pre-training has also given the best results (although LongT5 achieves lower overall prediction loss).
- Meta-Learning: Large-concept datasets, cross-item riddles, data augmentation, and high-quality NLP tasks show promise for enhanced meta-learning.
- Model Monitoring: Ongoing performance tracking is vital for keeping the model in-distribution for ARC. Adding augmented ARC riddles can help keep training in-distribution.
- Evaluation-Time Strategies: Test-time fine-tuning is critical for improving model performance. We also think it is unlikely that anyone can win ARC this year without ensembling. We have developed tools and methods for ensembling across multiple models and other approaches, partly inspired by existing ensembles. We have also developed new approaches, such as the Correct Answer Collection System (CACS). We don’t find ensembling to be a satisfying solution, but rather a pragmatic reality.
- Future Work: We aim to refine our ensembling techniques and delve deeper into meta-learning strategies. The best approach may ultimately be one that includes minimal or no ARC-related synthetic data. An empirical/experimental approach could lead to increasingly capable general models that perform well on ARC as well as a broad range of challenging tasks. Whether this is possible with current architectures is an open question. What is clear from our work and from others (e.g., Li et al., 2023) is that the upper limits of the capabilities of LLMs, even small ones, have not yet been discovered, and there is clear evidence that more generally capable models are possible.
References
Chung, H. W., Hou, L., Longpre, S., Zoph, B., Tay, Y., Fedus, W., Li, Y., Wang, X., Dehghani, M., Brahma, S., Webson, A., Gu, S. S., Dai, Z., Suzgun, M., Chen, X., Chowdhery, A., Castro-Ros, A., Pellat, M., Robinson, K., Valter, D., Narang, S., Mishra, G., Yu, A., Zhao, V., Huang, Y., Dai, A., Yu, H., Petrov, S., Chi, E. H., Dean, J., Devlin, J., Roberts, A., Zhou, D., Le, Q. V., & Wei, J. (2022). Scaling Instruction-Finetuned Language Models. arXiv:2210.11416 [cs.LG]. https://doi.org/10.48550/arXiv.2210.11416
Guo, M., Ainslie, J., Uthus, D., Ontanon, S., Ni, J., Sung, Y.-H., & Yang, Y. (2022). LongT5: Efficient Text-To-Text Transformer for Long Sequences. In Findings of the Association for Computational Linguistics: NAACL 2022. arXiv:2112.07916 [cs.CL]. https://doi.org/10.48550/arXiv.2112.07916
Kohavi, R. (1995). A Study of Cross-Validation and Bootstrap for Accuracy Estimation and Model Selection. In Proceedings of the International Joint Conference on Artificial Intelligence (IJCAI). Computer Science Department, Stanford University.
Li, Y., Bubeck, S., Eldan, R., Del Giorno, A., Gunasekar, S., & Lee, Y. T. (2023). Textbooks Are All You Need II: phi-1.5 Technical Report. arXiv:2309.05463 [cs.CL]. https://doi.org/10.48550/arXiv.2309.05463
Mirchandani, S., Xia, F., Florence, P., Ichter, B., Driess, D., Gonzalez Arenas, M., Rao, K., Sadigh, D., & Zeng, A. (2023). Large Language Models as General Pattern Machines. Stanford University; Google DeepMind; TU Berlin. Retrieved from https://general-pattern-machines.github.io
Wang, Y., Wang, W., Joty, S., & Hoi, S. C. H. (2021). CodeT5: Identifier-aware Unified Pre-trained Encoder-Decoder Models for Code Understanding and Generation. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP 2021). arXiv:2109.00859 [cs.CL]. https://doi.org/10.48550/arXiv.2109.00859